registry.py 38.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import importlib
8
import os
9
import pickle
10
11
import subprocess
import sys
12
import tempfile
13
from abc import ABC, abstractmethod
14
from collections.abc import Set
15
from dataclasses import dataclass, field
16
from functools import lru_cache
17
from typing import Callable, Optional, TypeVar, Union
18
19

import torch.nn as nn
20
import transformers
21

22
23
from vllm.config import (ModelConfig, ModelImpl, iter_architecture_defaults,
                         try_match_architecture_defaults)
24
from vllm.logger import init_logger
25
26
from vllm.transformers_utils.dynamic_module import (
    try_get_class_from_dynamic_module)
27

28
29
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
30
31
                         supports_multimodal,
                         supports_multimodal_encoder_tp_data,
32
                         supports_multimodal_raw_input_only, supports_pp,
33
                         supports_transcription, supports_v0_only)
34
35
from .interfaces_base import (get_default_pooling_type, is_pooling_model,
                              is_text_generation_model)
36
37
38

logger = init_logger(__name__)

39
# yapf: disable
40
41
# Maps a HuggingFace architecture name to `(module_name, class_name)` of the
# vLLM implementation for decoder-only text generation models.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "MotifForCausalLM": ("motif", "MotifForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Maps an architecture name to `(module_name, class_name)` for models that
# can be used for embedding/pooling.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Terratorch models take images as input and produce images as output;
    # they are listed here only because they currently piggy-back on the
    # embedding model path.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

194
195
# Maps an architecture name to `(module_name, class_name)` for cross-encoder
# (sequence classification / ranking) models.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "GteNewForSequenceClassification": ("bert_with_rope",
                                        "GteNewForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
}

208
# Maps an architecture name to `(module_name, class_name)` for multimodal
# models (vision / audio / video inputs).
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"),  # noqa: E501
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL": ("nano_nemotron_vl", "NemotronH_Nano_VL"),
    "InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "InternVLForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": ("keye_vl1_5", "KeyeVL1_5ForConditionalGeneration"),  # noqa: E501
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
273
274

# Maps an architecture name to `(module_name, class_name)` for draft /
# speculative-decoding models (EAGLE, MTP, Medusa, ...).
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    # Temporarily disabled.
    # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
291

292
# Architectures with no native vLLM implementation that are served through
# the Transformers backend.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}

# The Transformers backend classes themselves, addressable by name.
_TRANSFORMERS_BACKEND_MODELS = {
    "TransformersModel": ("transformers", "TransformersModel"),
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}
304
# yapf: enable
305

306
# Every architecture vLLM knows how to load, merged from the tables above.
# On duplicate keys the later table wins (dict unpacking semantics), so the
# order of the entries below matters.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

316
317
318
319
320
321
322
323
# Argument vector handed to `subprocess.run()` to execute the
# `vllm.model_executor.models.registry` module in a fresh interpreter.
# It is kept as a module-level variable so it can be overridden when needed,
# e.g. when things are packed together in the par format and `sys.executable`
# is not the interpreter that should be launched.
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]

324
325
326
327
328
329
330
331
332
333
334
# Architectures that were removed from vLLM, mapped to the last vLLM version
# that still supported them.
_PREVIOUSLY_SUPPORTED_MODELS = {
    "Phi3SmallForCausalLM": "0.9.2",
    # Encoder-decoder models other than Whisper were removed as part of the
    # V0 deprecation.
    **dict.fromkeys(
        (
            "BartModel",
            "BartForConditionalGeneration",
            "DonutForConditionalGeneration",
            "Florence2ForConditionalGeneration",
            "MBartForConditionalGeneration",
            "MllamaForConditionalGeneration",
        ),
        "0.10.2",
    ),
}
335

336

337
338
@dataclass(frozen=True)
class _ModelInfo:
    """Immutable snapshot of the capabilities of a model class.

    Build instances with `from_model_cls`. Each boolean field holds the result
    of the same-named interface predicate applied to the class, so callers can
    query capabilities without importing the class again.
    """

    # Name of the inspected class (`model.__name__`).
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    # Value returned by `get_default_pooling_type(model)`.
    default_pooling_type: str
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    # True only if the model supports transcription AND declares
    # `supports_transcription_only`.
    supports_transcription_only: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Inspect `model` and return its capabilities.

        Args:
            model: The model class (not an instance) to inspect.

        Returns:
            A `_ModelInfo` populated from the interface predicates.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=
            supports_multimodal_raw_input_only(model),
            supports_multimodal_encoder_tp_data=
            supports_multimodal_encoder_tp_data(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            # The attribute is only read once `supports_transcription` has
            # passed (short-circuit `and`).
            supports_transcription_only=(supports_transcription(model) and
                                         model.supports_transcription_only),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
379
380


381
class _BaseRegisteredModel(ABC):
    """A registry entry that can describe and load one model class."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capabilities of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
390
391


392
393
394
395
396
397
398
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capabilities computed once, in-process, when the entry is created.
    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap an imported model class, inspecting it immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capabilities captured at construction time."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the wrapped model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Importable module path and the attribute name of the class within it.
    module_name: str
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """Inspect the model class in a subprocess and return its info."""
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` in this process and return `class_name`.

        Raises whatever `importlib.import_module` or `getattr` raises when
        the module or attribute is missing.
        """
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load the model class registered for `model_arch`.

    The platform check runs first and outside the `try`, so any exception it
    raises propagates to the caller. Failures while loading the class itself
    are logged with a traceback and turned into `None`.

    Results (including `None`) are memoized per `(model_arch, model)` pair.

    Returns:
        The model class, or `None` if loading it raised.
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
446
447


448
449
450
451
452
453
454
455
456
457
458
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Return the capabilities of `model`, or `None` if inspection fails.

    Any exception raised during inspection is logged with a traceback,
    labelled with `model_arch`, instead of propagating. Results (including
    `None`) are memoized per `(model_arch, model)` pair.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        info = None
    return info
459
460


461
462
463
@dataclass
class _ModelRegistry:
    # Keyed by model_arch
464
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
465

466
    def get_supported_archs(self) -> Set[str]:
        """Return the registered architecture names.

        This is a live `dict.keys()` view, so it reflects later registrations.
        """
        return self.models.keys()
468

469
470
471
    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that already exists replaces the previous
        entry and logs a warning.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not of the form
                `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        # Validate and build the entry before warning, so we never announce
        # an overwrite that then fails to happen.
        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        self.models[model_arch] = model
510

511
    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a `ValueError` explaining why `architectures` can't be used.

        This method never returns normally. The message distinguishes:

        - at least one architecture is registered: reported as an inspection
          failure (details are in the logs);
        - an architecture is listed in `_PREVIOUSLY_SUPPORTED_MODELS`:
          reported with the last vLLM version that supported it;
        - otherwise: reported as unsupported, with the supported list.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")
532

533
    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Return the class for `model_arch`, or `None` if it is not
        registered or could not be loaded."""
        if model_arch not in self.models:
            return None

        # Resolves to the module-level cached helper of the same name, not
        # this method (class scope is not visible inside methods).
        return _try_load_model_cls(model_arch, self.models[model_arch])
539

540
    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Return the capabilities of `model_arch`, or `None` if it is not
        registered or could not be inspected."""
        if model_arch not in self.models:
            return None

        # Resolves to the module-level cached helper of the same name, not
        # this method (class scope is not visible inside methods).
        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> Optional[str]:
        """Resolve `architecture` to a Transformers-backend architecture.

        Returns:
            `architecture` itself if it already names a Transformers backend
            class; otherwise the name from
            `model_config._get_transformers_backend_cls()`; or `None` when no
            compatible Transformers class is found and
            `model_config.model_impl` is not `ModelImpl.TRANSFORMERS`.

        Raises:
            ValueError: If `model_config.model_impl` is
                `ModelImpl.TRANSFORMERS` but no model class can be found, or
                the class found is not backend-compatible.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
                                           None) or dict()

        # Load every `AutoConfig*` entry before any `AutoModel*` entry, since
        # the model class may need the config class to be loaded already.
        # Looping over the prefixes enforces this order regardless of how
        # `auto_map` itself is ordered, e.g.:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            # Not a Transformers built-in: use the first `AutoModel*` entry
            # of `auto_map` that yields a class.
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # The loop finished without `break`: no class was found.
                if model_config.model_impl != ModelImpl.TRANSFORMERS:
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom).")

        if not model_module.is_backend_compatible():
            if model_config.model_impl != ModelImpl.TRANSFORMERS:
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM.")

        return model_config._get_transformers_backend_cls()
608

609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Map ``architecture`` onto a registered architecture name if possible.

        If ``architecture`` is already in ``self.models`` it is returned as-is.
        Otherwise, when ``try_match_architecture_defaults`` reports a matching
        suffix for it, that trailing suffix is swapped for each suffix yielded
        by ``iter_architecture_defaults`` in turn. The first resulting name
        present in ``self.models`` is returned.

        Args:
            architecture: Architecture name to normalize.
            model_config: Only its ``runner_type`` / ``convert_type``
                attributes are read, and they may not be set yet.

        Returns:
            A registered architecture name, or ``architecture`` unchanged when
            no registered base architecture is found.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Only swap the *trailing* suffix. `str.replace` would also
            # rewrite any earlier occurrence of the suffix inside the name.
            stem = architecture.removesuffix(suffix)

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = stem + repl_suffix
                if base_arch in self.models:
                    return base_arch

        return architecture

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return the model info for the first architecture that resolves.

        Resolution order. Each step falls through to the next on failure:

        1. ``model_impl == TRANSFORMERS``: the Transformers backend for
           ``architectures[0]``.
        2. ``model_impl == TERRATORCH``: the ``"Terratorch"`` entry.
        3. ``model_impl == AUTO``, ``convert_type`` is ``"none"`` (or unset)
           and none of ``architectures`` is registered: the Transformers
           backend for ``architectures[0]``.
        4. Each architecture in order, looked up after ``_normalize_arch``.
        5. ``model_impl == AUTO`` and none of ``architectures`` is
           registered: the Transformers backend for ``architectures[0]``.

        Args:
            architectures: One architecture name or a list of candidates.
            model_config: Supplies ``model_impl`` and ``convert_type``.

        Returns:
            ``(model_info, arch)``. In step 4 ``arch`` is the caller-supplied
            name, not the normalized one.

        Raises:
            ValueError: If ``architectures`` is empty.
            Whatever ``_raise_for_unsupported`` raises when nothing resolves.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == ModelImpl.TERRATORCH:
            model_info = self._try_inspect_model_cls("Terratorch")
            # Mirror resolve_model_cls: never hand back a None model info.
            # Otherwise it would only surface later as an AttributeError in
            # the is_* helpers that read attributes off the result.
            if model_info is not None:
                return (model_info, "Terratorch")

        # Fallback to transformers impl (after resolving convert_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == ModelImpl.AUTO
                and getattr(model_config, "convert_type", "none") == "none"):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == ModelImpl.AUTO):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Load the model class for the first architecture that resolves.

        Candidates are tried in this order. Each falls through on failure:

        1. ``model_impl == TRANSFORMERS``: the Transformers backend for
           ``architectures[0]``.
        2. ``model_impl == TERRATORCH``: the ``"Terratorch"`` class.
        3. ``model_impl == AUTO``, ``convert_type`` is ``"none"`` (or unset)
           and no requested architecture is registered: the Transformers
           backend.
        4. Each requested architecture, loaded after ``_normalize_arch``.
        5. ``model_impl == AUTO`` and no requested architecture is
           registered: the Transformers backend once more.

        Returns:
            ``(model_cls, arch)``. For step 4 ``arch`` is the requested name
            rather than the normalized one.

        Raises:
            ValueError: If ``architectures`` is empty.
            Whatever ``_raise_for_unsupported`` raises when nothing resolves.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        def via_transformers() -> Optional[tuple[type[nn.Module], str]]:
            # Resolve the first requested architecture through the
            # Transformers backend and load it. None if either step fails.
            resolved_arch = self._try_resolve_transformers(architectures[0],
                                                           model_config)
            if resolved_arch is None:
                return None
            loaded_cls = self._try_load_model_cls(resolved_arch)
            if loaded_cls is None:
                return None
            return (loaded_cls, resolved_arch)

        def none_registered() -> bool:
            return not any(name in self.models for name in architectures)

        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            # Require transformers impl
            found = via_transformers()
            if found is not None:
                return found
        elif model_config.model_impl == ModelImpl.TERRATORCH:
            terratorch_cls = self._try_load_model_cls("Terratorch")
            if terratorch_cls is not None:
                return (terratorch_cls, "Terratorch")

        # Fallback to transformers impl (after resolving convert_type)
        if (none_registered()
                and model_config.model_impl == ModelImpl.AUTO
                and getattr(model_config, "convert_type", "none") == "none"):
            found = via_transformers()
            if found is not None:
                return found

        for requested in architectures:
            candidate_cls = self._try_load_model_cls(
                self._normalize_arch(requested, model_config))
            if candidate_cls is not None:
                return (candidate_cls, requested)

        # Fallback to transformers impl (before resolving runner_type)
        if none_registered() and model_config.model_impl == ModelImpl.AUTO:
            found = via_transformers()
            if found is not None:
                return found

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return ``is_text_generation_model`` of the inspected model info."""
        model_info, _arch = self.inspect_model_cls(architectures,
                                                   model_config)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return ``is_pooling_model`` of the inspected model info."""
        model_info, _arch = self.inspect_model_cls(architectures,
                                                   model_config)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return ``supports_cross_encoding`` of the inspected model info."""
        model_info, _arch = self.inspect_model_cls(architectures,
                                                   model_config)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return ``supports_multimodal`` of the inspected model info."""
        model_info, _arch = self.inspect_model_cls(architectures,
                                                   model_config)
        return model_info.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return ``supports_multimodal_raw_input_only`` of the model info."""
        model_info, _arch = self.inspect_model_cls(architectures,
                                                   model_config)
        return model_info.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return ``supports_pp`` of the inspected model info."""
        model_info, _arch = self.inspect_model_cls(architectures,
                                                   model_config)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return ``has_inner_state`` of the inspected model info."""
        model_info, _arch = self.inspect_model_cls(architectures,
                                                   model_config)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return ``is_attention_free`` of the inspected model info."""
        model_info, _arch = self.inspect_model_cls(architectures,
                                                   model_config)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return ``is_hybrid`` of the inspected model info."""
        model_info, _arch = self.inspect_model_cls(architectures,
                                                   model_config)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return ``has_noops`` of the inspected model info."""
        model_info, _arch = self.inspect_model_cls(architectures,
                                                   model_config)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return ``supports_transcription`` of the inspected model info."""
        model_info, _arch = self.inspect_model_cls(architectures,
                                                   model_config)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return ``supports_transcription_only`` of the inspected model info."""
        model_info, _arch = self.inspect_model_cls(architectures,
                                                   model_config)
        return model_info.supports_transcription_only

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return True unless the inspected model info sets ``supports_v0_only``."""
        model_info, _arch = self.inspect_model_cls(architectures,
                                                   model_config)
        v0_only = model_info.supports_v0_only
        return not v0_only


# Registry of all built-in architectures. Each architecture name in
# `_VLLM_MODELS` maps to a `_LazyRegisteredModel` that records the module under
# `vllm.model_executor.models` and the class name to load from it.
ModelRegistry = _ModelRegistry({
    arch_name: _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{module_relname}",
        class_name=class_name,
    )
    for arch_name, (module_relname, class_name) in _VLLM_MODELS.items()
})

# Generic return type of `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run ``fn`` in a separate Python process and return its result.

    ``fn`` and a temporary output path are serialized with ``cloudpickle``
    (so lambdas are supported) and passed on stdin to ``_SUBPROCESS_COMMAND``.
    The child (``_run`` in this module) writes the pickled result to that
    path, and the result is unpickled here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status. The
            message includes the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently: a strict decode of non-UTF-8 stderr would raise
            # UnicodeDecodeError here and hide the real subprocess failure.
            stderr = returned.stderr.decode(errors="replace")
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr}") from e

        # The file was written by our own child process, so unpickling it is
        # not untrusted input.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Subprocess entry point paired with ``_run_in_subprocess``.

    Loads vLLM's general plugins, reads a pickled ``(fn, output_file)`` pair
    from stdin, calls ``fn()`` and writes its pickled return value to
    ``output_file``.
    """
    # Setup plugins before unpickling. Presumably this lets plugin-provided
    # code be resolved while `fn` is loaded and executed (TODO confirm).
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE(review): the payload is expected to come from the parent process
    # in `_run_in_subprocess`, so `pickle.loads` is not applied to untrusted
    # data here.
    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    outcome = fn()

    with open(output_file, "wb") as out:
        out.write(pickle.dumps(outcome))


if __name__ == "__main__":
    # Executed when this module is launched as a script, presumably via
    # `_SUBPROCESS_COMMAND` from `_run_in_subprocess` (TODO confirm).
    _run()