# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Whenever you add an architecture to this registry, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import importlib
8
import os
9
import pickle
10
11
import subprocess
import sys
12
import tempfile
13
from abc import ABC, abstractmethod
14
from collections.abc import Set
15
from dataclasses import dataclass, field
16
from functools import lru_cache
17
from typing import Callable, Optional, TypeVar, Union
18
19

import torch.nn as nn
20
import transformers
21

22
23
from vllm.config import (ModelConfig, ModelImpl, iter_architecture_defaults,
                         try_match_architecture_defaults)
24
from vllm.logger import init_logger
25
26
from vllm.transformers_utils.dynamic_module import (
    try_get_class_from_dynamic_module)
27

28
29
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
30
31
                         supports_multimodal,
                         supports_multimodal_encoder_tp_data,
32
                         supports_multimodal_raw_input_only, supports_pp,
33
                         supports_transcription, supports_v0_only)
34
35
from .interfaces_base import (get_default_pooling_type, is_pooling_model,
                              is_text_generation_model)
36
37
38

# Module-level logger; used below to warn about re-registered architectures
# and to log exceptions raised while loading or inspecting model classes.
logger = init_logger(__name__)

39
# yapf: disable

# Maps a HuggingFace architecture name to the (module under
# `vllm.model_executor.models`, class name) tuple that implements it.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),  # noqa: E501
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
    "MBartForConditionalGeneration": ("bart", "MBartForConditionalGeneration"),
}

_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically, Terratorch models work on images for both input and
    # output. They are listed here because they piggy-back on the embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "GteNewForSequenceClassification": ("bert_with_rope",
                                        "GteNewForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501
}

_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"),  # noqa: E501
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "InternVLForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": ("keye_vl1_5", "KeyeVL1_5ForConditionalGeneration"),  # noqa: E501
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "DonutForConditionalGeneration": ("donut", "DonutForConditionalGeneration"),
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}

_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611  # noqa: E501
    # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

# Architectures that vLLM serves only through the Transformers backend.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}

# The Transformers backend classes themselves, addressable by name.
_TRANSFORMERS_BACKEND_MODELS = {
    "TransformersModel": ("transformers", "TransformersModel"),
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}
# yapf: enable

# Every architecture vLLM knows about. When a name appears in more than one
# table, the entry from the table unpacked last wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

# This variable is used as the args for subprocess.run(). It can be modified
# to alter the args if needed, e.g. when using the PAR format to pack things
# together, sys.executable might not be the target we want to run.
_SUBPROCESS_COMMAND = [
    sys.executable, "-m", "vllm.model_executor.models.registry"
]

# Architectures that are no longer supported, mapped to the last vLLM version
# that supported them (reported back to the user in the error message).
_PREVIOUSLY_SUPPORTED_MODELS = {"Phi3SmallForCausalLM": "0.9.2"}


@dataclass(frozen=True)
class _ModelInfo:
    """Summary of the interfaces implemented by a model class.

    Instances are normally built with :meth:`from_model_cls`. Each boolean
    field holds the result of the predicate with the same name from
    ``.interfaces`` / ``.interfaces_base``, applied to the model class.
    The dataclass is frozen, so instances are immutable and hashable.
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    # Result of ``get_default_pooling_type(model)``.
    default_pooling_type: str
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    # True only if the class supports transcription AND sets the class
    # attribute ``supports_transcription_only``.
    supports_transcription_only: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Inspect ``model`` and record which interfaces it implements.

        Args:
            model: The model class (not an instance) to inspect.

        Returns:
            A frozen ``_ModelInfo`` describing ``model``.
        """
        # Evaluated once because two fields depend on it.
        transcription = supports_transcription(model)

        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=
            supports_multimodal_raw_input_only(model),
            supports_multimodal_encoder_tp_data=
            supports_multimodal_encoder_tp_data(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=transcription,
            # Short-circuit: the attribute is only read on classes that
            # implement the transcription interface.
            supports_transcription_only=(transcription and
                                         model.supports_transcription_only),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )


class _BaseRegisteredModel(ABC):
    """Abstract registry entry that can describe and load one model class."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the interface summary of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError


@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    The interface summary is computed eagerly when the entry is created, so
    both accessors simply return stored values.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap an already-imported model class, inspecting it right away."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    The class is identified by ``module_name`` and ``class_name`` and is
    only imported when :meth:`load_model_cls` is called.
    """
    module_name: str
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        # Raises ImportError / AttributeError if the target does not exist.
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load the class behind ``model``; return ``None`` if loading fails.

    Results are memoized per ``(model_arch, model)`` pair. The platform
    check runs outside the ``try`` block, so an exception raised by
    ``current_platform.verify_model_arch`` propagates to the caller.
    Exceptions raised while loading the class itself are logged with a
    traceback and turned into ``None``.
    """
    # Imported locally rather than at module level; presumably to avoid an
    # import cycle / eager platform detection — TODO confirm.
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)

    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None


@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Return the interface summary of ``model``, or ``None`` on error.

    Any exception raised during inspection is logged together with its
    traceback and swallowed. Because of ``lru_cache``, the outcome
    (including a ``None`` failure) is memoized per ``(model_arch, model)``.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        info = None
    return info


@dataclass
class _ModelRegistry:
    # Keyed by model_arch
454
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
455

456
    def get_supported_archs(self) -> Set[str]:
        """Return a live view of the registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that already exists logs a warning and
        replaces the previous entry.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not of the form
                `<module>:<class>` with both parts non-empty.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            # Both halves must be non-empty; otherwise the lazy import would
            # only fail much later, when the model is actually loaded.
            if len(split_str) != 2 or not all(split_str):
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending `model_cls` argument
            # (`model_arch` is already known to be a str at this point).
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        # Warn only once the new entry is known to be valid, so the message
        # never announces an overwrite that does not actually happen.
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a ``ValueError`` explaining why ``architectures`` are unusable.

        The cases are checked in order:

        1. At least one architecture is registered: it failed to be inspected
           (details are in the logs).
        2. An architecture was supported by an older vLLM release.
        3. None of the architectures is known.

        This method never returns normally.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Load the class registered for ``model_arch``.

        Returns ``None`` when ``model_arch`` is not registered, or when the
        module-level ``_try_load_model_cls`` helper fails to load the class.
        """
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Return the interface summary of the class registered for ``model_arch``.

        Returns ``None`` when ``model_arch`` is not registered, or when the
        module-level ``_try_inspect_model_cls`` helper fails to inspect it.
        """
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> Optional[str]:
        """Resolve ``architecture`` to a Transformers-backend architecture.

        If ``architecture`` already names a Transformers backend class, it is
        returned unchanged. Otherwise the custom classes listed in the HF
        config's ``auto_map`` are loaded. The model class is then looked up
        in the installed ``transformers`` package, falling back to the
        ``AutoModel*`` entries of ``auto_map``. If that class is backend
        compatible, ``model_config._get_transformers_backend_cls()`` is
        returned.

        Returns:
            The backend architecture name. Returns ``None`` if the model
            cannot use the Transformers backend and
            ``model_config.model_impl`` is not ``ModelImpl.TRANSFORMERS``.

        Raises:
            ValueError: If ``model_impl`` is ``ModelImpl.TRANSFORMERS`` but
                no model class can be found or it is not backend compatible.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
                                           None) or dict()

        # Custom config classes must be initialized before custom model
        # classes, otherwise the model class won't be able to access its
        # config class. Iterating over the prefixes in the outer loop enforces
        # this order regardless of the key order inside `auto_map`, e.g.:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<model-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<model-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        # Prefer a class shipped with the installed Transformers library.
        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            # Fall back to the first custom `AutoModel*` class that loads.
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # No `AutoModel*` entry exists, or none of them could be
                # loaded.
                if model_config.model_impl != ModelImpl.TRANSFORMERS:
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom).")

        if not model_module.is_backend_compatible():
            if model_config.model_impl != ModelImpl.TRANSFORMERS:
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM.")

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Map ``architecture`` onto a key of ``self.models`` when possible.

        Registered names are returned unchanged. Otherwise, if
        ``try_match_architecture_defaults`` matches the name (using the
        config's ``runner_type``/``convert_type`` when those attributes
        exist), the matched suffix is swapped for each default suffix from
        ``iter_architecture_defaults`` and the first registered result is
        returned. If nothing is found, ``architecture`` is returned as-is.

        Args:
            architecture: Architecture name to normalize.
            model_config: Configuration that may carry ``runner_type`` and
                ``convert_type``.

        Returns:
            A registered architecture name, or ``architecture`` unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # When the matched suffix really ends the name, swap only that
            # trailing part; `str.replace` would also rewrite any earlier
            # occurrence of the same substring. Otherwise keep the previous
            # replace-based behaviour.
            if suffix and architecture.endswith(suffix):
                stem: Optional[str] = architecture[:-len(suffix)]
            else:
                stem = None

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                if stem is not None:
                    base_arch = stem + repl_suffix
                else:
                    base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return ``(model_info, arch)`` for the first supported architecture.

        Candidates are tried in this order:

        1. ``model_impl == TRANSFORMERS``: the Transformers backend resolved
           from the first architecture.
        2. ``model_impl == TERRATORCH``: the ``"Terratorch"`` entry.
        3. ``model_impl == AUTO``, none of ``architectures`` registered and
           ``convert_type`` is ``"none"`` (or absent): the Transformers
           backend.
        4. Each architecture in order, looked up after ``_normalize_arch``;
           the returned name is the caller's original (un-normalized) one.
        5. ``model_impl == AUTO`` and none registered: the Transformers
           backend again, regardless of ``convert_type``.

        Steps 1 and 2 fall through to the later steps when they find nothing.
        If every step fails, the result of
        ``self._raise_for_unsupported(architectures)`` is returned
        (presumably it raises).

        Args:
            architectures: A single architecture name or a list of them.
            model_config: Configuration consulted for ``model_impl`` and, when
                present, ``convert_type``/``runner_type``.

        Returns:
            The inspected model info and the architecture name it was found
            under.

        Raises:
            ValueError: If ``architectures`` is empty; also propagated from
                ``_try_resolve_transformers`` when the Transformers
                implementation is required but unusable.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        def via_transformers() -> Optional[tuple[_ModelInfo, str]]:
            # Inspect the Transformers backend class for the first
            # architecture; None when it cannot be resolved or inspected.
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is None:
                return None
            model_info = self._try_inspect_model_cls(arch)
            if model_info is None:
                return None
            return (model_info, arch)

        # Require transformers impl
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            result = via_transformers()
            if result is not None:
                return result
        elif model_config.model_impl == ModelImpl.TERRATORCH:
            model_info = self._try_inspect_model_cls("Terratorch")
            # `_try_inspect_model_cls` can return None; never hand out
            # `(None, "Terratorch")` (matches `resolve_model_cls`).
            if model_info is not None:
                return (model_info, "Terratorch")

        # Fallback to transformers impl (after resolving convert_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == ModelImpl.AUTO
                and getattr(model_config, "convert_type", "none") == "none"):
            result = via_transformers()
            if result is not None:
                return result

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == ModelImpl.AUTO):
            result = via_transformers()
            if result is not None:
                return result

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return ``(model_cls, arch)`` for the first loadable architecture.

        The resolution order is the same as in ``inspect_model_cls``, but each
        candidate is loaded with ``_try_load_model_cls``:

        1. ``model_impl == TRANSFORMERS``: the Transformers backend resolved
           from the first architecture.
        2. ``model_impl == TERRATORCH``: the ``"Terratorch"`` entry.
        3. ``model_impl == AUTO``, none of ``architectures`` registered and
           ``convert_type`` is ``"none"`` (or absent): the Transformers
           backend.
        4. Each architecture in order, loaded after ``_normalize_arch``; the
           returned name is the caller's original (un-normalized) one.
        5. ``model_impl == AUTO`` and none registered: the Transformers
           backend again, regardless of ``convert_type``.

        Steps 1 and 2 fall through to the later steps when they find nothing.
        If every step fails, the result of
        ``self._raise_for_unsupported(architectures)`` is returned
        (presumably it raises).

        Args:
            architectures: A single architecture name or a list of them.
            model_config: Configuration consulted for ``model_impl`` and, when
                present, ``convert_type``/``runner_type``.

        Returns:
            The loaded model class and the architecture name it was found
            under.

        Raises:
            ValueError: If ``architectures`` is empty; also propagated from
                ``_try_resolve_transformers`` when the Transformers
                implementation is required but unusable.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        def via_transformers() -> Optional[tuple[type[nn.Module], str]]:
            # Load the Transformers backend class for the first
            # architecture; None when it cannot be resolved or loaded.
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is None:
                return None
            model_cls = self._try_load_model_cls(arch)
            if model_cls is None:
                return None
            return (model_cls, arch)

        # Require transformers impl
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            result = via_transformers()
            if result is not None:
                return result
        elif model_config.model_impl == ModelImpl.TERRATORCH:
            model_cls = self._try_load_model_cls("Terratorch")
            if model_cls is not None:
                return (model_cls, "Terratorch")

        # Fallback to transformers impl (after resolving convert_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == ModelImpl.AUTO
                and getattr(model_config, "convert_type", "none") == "none"):
            result = via_transformers()
            if result is not None:
                return result

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == ModelImpl.AUTO):
            result = via_transformers()
            if result is not None:
                return result

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's ``is_text_generation_model`` flag.

        Resolution goes through ``inspect_model_cls``; any exception it
        raises (e.g. ``ValueError`` for an empty architecture list)
        propagates.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's ``is_pooling_model`` flag.

        Resolution goes through ``inspect_model_cls``; any exception it
        raises (e.g. ``ValueError`` for an empty architecture list)
        propagates.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's ``supports_cross_encoding`` flag.

        Resolution goes through ``inspect_model_cls``; any exception it
        raises (e.g. ``ValueError`` for an empty architecture list)
        propagates.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's ``supports_multimodal`` flag.

        Resolution goes through ``inspect_model_cls``; any exception it
        raises (e.g. ``ValueError`` for an empty architecture list)
        propagates.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's ``supports_multimodal_raw_input_only``.

        Resolution goes through ``inspect_model_cls``; any exception it
        raises (e.g. ``ValueError`` for an empty architecture list)
        propagates.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's ``supports_pp`` flag.

        Resolution goes through ``inspect_model_cls``; any exception it
        raises (e.g. ``ValueError`` for an empty architecture list)
        propagates.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's ``has_inner_state`` flag.

        Resolution goes through ``inspect_model_cls``; any exception it
        raises (e.g. ``ValueError`` for an empty architecture list)
        propagates.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's ``is_attention_free`` flag.

        Resolution goes through ``inspect_model_cls``; any exception it
        raises (e.g. ``ValueError`` for an empty architecture list)
        propagates.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's ``is_hybrid`` flag.

        Resolution goes through ``inspect_model_cls``; any exception it
        raises (e.g. ``ValueError`` for an empty architecture list)
        propagates.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's ``has_noops`` flag.

        Resolution goes through ``inspect_model_cls``; any exception it
        raises (e.g. ``ValueError`` for an empty architecture list)
        propagates.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's ``supports_transcription`` flag.

        Resolution goes through ``inspect_model_cls``; any exception it
        raises (e.g. ``ValueError`` for an empty architecture list)
        propagates.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the resolved model's ``supports_transcription_only`` flag.

        Resolution goes through ``inspect_model_cls``; any exception it
        raises (e.g. ``ValueError`` for an empty architecture list)
        propagates.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return ``True`` unless the resolved model sets ``supports_v0_only``.

        Resolution goes through ``inspect_model_cls``; any exception it
        raises (e.g. ``ValueError`` for an empty architecture list)
        propagates.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return not model_info.supports_v0_only


# Default registry: every entry of `_VLLM_MODELS` is wrapped in a
# `_LazyRegisteredModel` pointing at
# `vllm.model_executor.models.<mod_relname>.<cls_name>` (presumably imported
# only on first use, as the wrapper's name suggests).
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

# Return type of the callable passed to `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run ``fn`` in a separate Python process and return its result.

    ``(fn, output_filepath)`` is serialized with ``cloudpickle`` (so lambdas
    work) and written to the stdin of ``_SUBPROCESS_COMMAND``. The child is
    expected to pickle the result of ``fn()`` into ``output_filepath``
    (presumably via ``_run`` in this module), and that file is unpickled here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status; the
            message contains the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the error to surface the child's stderr. Decode leniently
            # so a non-UTF-8 byte cannot replace the real failure with a
            # UnicodeDecodeError.
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr_text}") from e

        # The output file was written by the child process launched above,
        # so unpickling it here does not involve external input.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Subprocess entry point: execute a pickled callable read from stdin.

    Loads general plugins, reads the ``(fn, output_file)`` tuple that
    ``_run_in_subprocess`` sends on stdin, calls ``fn()``, and pickles the
    result into ``output_file``.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # SECURITY NOTE: stdin is unpickled without validation. This is only
    # acceptable because the payload comes from the parent process; never
    # point this entry point at untrusted input.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)


if __name__ == "__main__":
    # Executed as the child side of `_run_in_subprocess`.
    _run()