registry.py 38.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import importlib
8
import os
9
import pickle
10
11
import subprocess
import sys
12
import tempfile
13
from abc import ABC, abstractmethod
14
from collections.abc import Set
15
from dataclasses import dataclass, field
16
from functools import lru_cache
17
from typing import Callable, Optional, TypeVar, Union
18
19

import torch.nn as nn
20
import transformers
21

22
23
from vllm.config import (ModelConfig, ModelImpl, iter_architecture_defaults,
                         try_match_architecture_defaults)
24
from vllm.logger import init_logger
25
26
from vllm.transformers_utils.dynamic_module import (
    try_get_class_from_dynamic_module)
27

28
29
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
30
31
                         supports_multimodal,
                         supports_multimodal_encoder_tp_data,
32
                         supports_multimodal_raw_input_only, supports_pp,
33
                         supports_transcription, supports_v0_only)
34
35
from .interfaces_base import (get_default_pooling_type, is_pooling_model,
                              is_text_generation_model)
36
37
38

logger = init_logger(__name__)

# Formatting is disabled for the registry tables below and re-enabled by the
# matching `# yapf: enable` after `_TRANSFORMERS_BACKEND_MODELS`.
# yapf: disable
# Each registry table maps an architecture name to a
# (model module name, class name) pair naming vLLM's implementation.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),  # noqa: E501
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "MotifForCausalLM": ("motif", "MotifForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
    "MBartForConditionalGeneration": ("bart", "MBartForConditionalGeneration"),
}

_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include every
        # text-generation alias that resolves to LlamaForCausalLM.
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Terratorch models take images as input and produce images as output.
    # They are listed here because, for the time being, they piggy-back on
    # the embedding-model machinery.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "GteNewForSequenceClassification": ("bert_with_rope",
                                        "GteNewForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501
}

_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"),  # noqa: E501
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL": ("nano_nemotron_vl", "NemotronH_Nano_VL"),
    "InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "InternVLForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": ("keye_vl1_5", "KeyeVL1_5ForConditionalGeneration"),  # noqa: E501
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "DonutForConditionalGeneration": ("donut", "DonutForConditionalGeneration"),
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}

_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    # Temporarily disabled.
    # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}

_TRANSFORMERS_BACKEND_MODELS = {
    "TransformersModel": ("transformers", "TransformersModel"),
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}
# yapf: enable

# Union of every registry table. On duplicate keys the table listed later
# wins, following normal dict-unpacking precedence.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

# This variable is used as the args for subprocess.run(). We
# can modify this variable to alter the args if needed, e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
_SUBPROCESS_COMMAND = [
    sys.executable, "-m", "vllm.model_executor.models.registry"
]

# Architectures that vLLM no longer supports, mapped to the last vLLM
# version that supported them. Used to give users a targeted error message
# pointing at that version.
_PREVIOUSLY_SUPPORTED_MODELS = {"Phi3SmallForCausalLM": "0.9.2"}


@dataclass(frozen=True)
class _ModelInfo:
    """Immutable summary of the interfaces a model class implements.

    Every field is filled in by `from_model_cls`, which runs the
    interface-probing helpers imported at the top of this module against
    the model class.
    """

    architecture: str  # the model class's `__name__`
    is_text_generation_model: bool
    is_pooling_model: bool
    default_pooling_type: str
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    # Only true when the model supports transcription AND its class sets
    # `supports_transcription_only`.
    supports_transcription_only: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Probe `model` with the interface helpers and collect the results.

        Args:
            model: The model class (not an instance) to inspect.

        Returns:
            A frozen `_ModelInfo` describing `model`.
        """
        # Evaluated once and reused for both transcription-related fields.
        transcription = supports_transcription(model)
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=
            supports_multimodal_raw_input_only(model),
            supports_multimodal_encoder_tp_data=
            supports_multimodal_encoder_tp_data(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=transcription,
            supports_transcription_only=(transcription and
                                         model.supports_transcription_only),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )


class _BaseRegisteredModel(ABC):
    """Interface for a registry entry that can describe and load a model."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` for the model class behind this entry."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class behind this entry."""
        raise NotImplementedError


@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    When built via `from_model_cls`, the `_ModelInfo` is computed eagerly in
    the current process and stored alongside the class.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Wrap an already-imported model class, inspecting it immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        # The inspection result is stored on the entry; no work is needed.
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    Only the import location is stored; the module is imported on demand.
    """
    module_name: str  # module path passed to `importlib.import_module`
    class_name: str  # attribute name of the model class in that module

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """Import and inspect the model via `_run_in_subprocess`.

        This keeps the import (and any CUDA initialization it triggers) out
        of the current process.
        """
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` in this process and return `class_name`."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load the class behind a registry entry, returning None on failure.

    The platform check runs outside the `try`, so anything it raises
    propagates to the caller instead of being logged and swallowed.
    Results are memoized per `(model_arch, model)` pair by `lru_cache`;
    this includes a `None` result, so a failed load is not retried.

    Args:
        model_arch: Architecture name, used for the platform check and in
            the error log message.
        model: Registry entry whose class should be loaded.

    Returns:
        The model class, or None if loading raised (the exception is logged).
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None


@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Inspect a registry entry, returning None if inspection raises.

    Results (including a `None` failure) are memoized per
    `(model_arch, model)` pair by `lru_cache`.

    Args:
        model_arch: Architecture name, used only in the error log message.
        model: Registry entry to inspect.

    Returns:
        The entry's `_ModelInfo`, or None if inspection raised (the
        exception is logged).
    """
    try:
        return model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        return None


@dataclass
class _ModelRegistry:
    # Keyed by model_arch
459
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
460

461
    def get_supported_archs(self) -> Set[str]:
        """Return the names of all registered architectures.

        The result is a live `dict` keys view, so it reflects architectures
        registered after the call.
        """
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that already exists overwrites the
        previous entry, and a warning is logged.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor a `torch.nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not of the form
                `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        # Validate and build the entry first, so that a rejected `model_cls`
        # does not emit a misleading "will be overwritten" warning.
        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the rejected `model_cls`, not `model_arch`.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a `ValueError` explaining why `architectures` can't be used.

        This always raises. The message depends on the situation:

        - At least one architecture is registered. It therefore failed
          inspection; the `_try_*` helpers log the underlying exception.
        - An architecture was removed from vLLM. The message names the last
          version that supported it.
        - Otherwise, none of the architectures is known to vLLM.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Return the class for `model_arch`, or None if it is unregistered
        or fails to load (see the module-level `_try_load_model_cls`)."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Return the `_ModelInfo` for `model_arch`, or None if it is
        unregistered or inspection fails (see the module-level
        `_try_inspect_model_cls`)."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> Optional[str]:
        """Resolve `architecture` to a Transformers-backend architecture.

        Args:
            architecture: Architecture name to resolve.
            model_config: Config supplying `hf_config.auto_map`, the model
                path/revision used to fetch remote code, and `model_impl`.

        Returns:
            - `architecture` itself if it already names a Transformers-backend
              entry.
            - Otherwise, the result of
              `model_config._get_transformers_backend_cls()`.
            - None when no compatible Transformers implementation is found
              and `model_config.model_impl` is not `ModelImpl.TRANSFORMERS`.

        Raises:
            ValueError: If `model_config.model_impl` is
                `ModelImpl.TRANSFORMERS` but no model class can be found, or
                the class found is not backend-compatible.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
                                           None) or dict()

        # Import every `AutoConfig*` class before any `AutoModel*` class: a
        # remote model class may need its config class to be importable.
        # Iterating the prefixes in this order enforces that regardless of
        # the order of entries in `auto_map`.
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # No `AutoModel*` entry produced a class (the loop did not
                # `break`).
                if model_config.model_impl != ModelImpl.TRANSFORMERS:
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom).")

        if not model_module.is_backend_compatible():
            if model_config.model_impl != ModelImpl.TRANSFORMERS:
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM.")

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Map ``architecture`` onto a registered base architecture if possible.

        A name that is already in ``self.models`` is returned unchanged.
        Otherwise, if the name matches one of the architecture defaults for
        the config's ``runner_type``/``convert_type``, the matched suffix is
        swapped for each default suffix in turn. The first candidate found in
        ``self.models`` is returned. If there is no match, ``architecture`` is
        returned as-is.
        """
        if architecture in self.models:
            return architecture

        # This can run while runner_type / convert_type are still being
        # resolved, so missing attributes are passed as None and the default
        # match is used.
        matched = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if not matched:
            return architecture

        matched_suffix, _ = matched

        # Candidate base-model names, built lazily in default-suffix order.
        # NOTE(review): str.replace substitutes every occurrence of the
        # suffix, not only a trailing one.
        candidates = (architecture.replace(matched_suffix, other_suffix)
                      for other_suffix, _ in iter_architecture_defaults())
        return next((name for name in candidates if name in self.models),
                    architecture)

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return the ``_ModelInfo`` for the first resolvable architecture.

        Resolution order (first hit wins):

        1. ``model_impl == TRANSFORMERS``: the Transformers backend resolved
           for ``architectures[0]``.
        2. ``model_impl == TERRATORCH``: the ``"Terratorch"`` registry entry.
        3. ``model_impl == AUTO``, no candidate registered natively, and
           ``convert_type`` is ``"none"`` (or unset): the Transformers backend.
        4. Each candidate in order, after :meth:`_normalize_arch`.
        5. ``model_impl == AUTO`` and no candidate registered natively: the
           Transformers backend.

        If step 1 or 2 finds nothing, the later steps are still tried.

        Args:
            architectures: One architecture name or a list of candidates.
            model_config: Config whose ``model_impl`` (and ``convert_type``,
                if already set) steers resolution.

        Returns:
            ``(model_info, arch)``. In step 4, ``arch`` is the candidate as
            given, not its normalized form.

        Raises:
            ValueError: If ``architectures`` is empty.

        If nothing resolves, returns ``self._raise_for_unsupported(
        architectures)``, which is expected to raise (not visible here).
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == ModelImpl.TERRATORCH:
            # Only return on success, matching resolve_model_cls. Returning
            # a None info would make every is_*/supports_* helper fail with
            # an AttributeError instead of a clear "unsupported" error.
            arch = "Terratorch"
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == ModelImpl.AUTO
                and getattr(model_config, "convert_type", "none") == "none"):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == ModelImpl.AUTO):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Load the model class for the first resolvable architecture.

        Resolution order (first hit wins):

        1. ``model_impl == TRANSFORMERS``: the Transformers backend resolved
           for ``architectures[0]``.
        2. ``model_impl == TERRATORCH``: the ``"Terratorch"`` registry entry.
        3. ``model_impl == AUTO``, no candidate registered natively, and
           ``convert_type`` is ``"none"`` (or unset): the Transformers backend.
        4. Each candidate in order, after :meth:`_normalize_arch`.
        5. ``model_impl == AUTO`` and no candidate registered natively: the
           Transformers backend.

        If step 1 or 2 finds nothing, the later steps are still tried.

        Args:
            architectures: One architecture name or a list of candidates.
            model_config: Config whose ``model_impl`` (and ``convert_type``,
                if already set) steers resolution.

        Returns:
            ``(model_cls, arch)``. In step 4, ``arch`` is the candidate as
            given, not its normalized form.

        Raises:
            ValueError: If ``architectures`` is empty.

        If nothing resolves, returns ``self._raise_for_unsupported(
        architectures)``, which is expected to raise (not visible here).
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
        elif model_config.model_impl == ModelImpl.TERRATORCH:
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == ModelImpl.AUTO
                and getattr(model_config, "convert_type", "none") == "none"):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == ModelImpl.AUTO):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``is_text_generation_model`` flag of the inspected model.

        Resolution goes through :meth:`inspect_model_cls`, so its errors
        (e.g. ``ValueError`` for empty ``architectures``) propagate.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``is_pooling_model`` flag of the inspected model.

        Resolution goes through :meth:`inspect_model_cls`, so its errors
        (e.g. ``ValueError`` for empty ``architectures``) propagate.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``supports_cross_encoding`` flag of the inspected model.

        Resolution goes through :meth:`inspect_model_cls`, so its errors
        (e.g. ``ValueError`` for empty ``architectures``) propagate.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``supports_multimodal`` flag of the inspected model.

        Resolution goes through :meth:`inspect_model_cls`, so its errors
        (e.g. ``ValueError`` for empty ``architectures``) propagate.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``supports_multimodal_raw_input_only`` flag.

        The flag is read from the ``_ModelInfo`` returned by
        :meth:`inspect_model_cls`, whose errors propagate.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``supports_pp`` flag of the inspected model.

        Resolution goes through :meth:`inspect_model_cls`, so its errors
        (e.g. ``ValueError`` for empty ``architectures``) propagate.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``has_inner_state`` flag of the inspected model.

        Resolution goes through :meth:`inspect_model_cls`, so its errors
        (e.g. ``ValueError`` for empty ``architectures``) propagate.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``is_attention_free`` flag of the inspected model.

        Resolution goes through :meth:`inspect_model_cls`, so its errors
        (e.g. ``ValueError`` for empty ``architectures``) propagate.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``is_hybrid`` flag of the inspected model.

        Resolution goes through :meth:`inspect_model_cls`, so its errors
        (e.g. ``ValueError`` for empty ``architectures``) propagate.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``has_noops`` flag of the inspected model.

        Resolution goes through :meth:`inspect_model_cls`, so its errors
        (e.g. ``ValueError`` for empty ``architectures``) propagate.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``supports_transcription`` flag of the inspected model.

        Resolution goes through :meth:`inspect_model_cls`, so its errors
        (e.g. ``ValueError`` for empty ``architectures``) propagate.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``supports_transcription_only`` flag.

        The flag is read from the ``_ModelInfo`` returned by
        :meth:`inspect_model_cls`, whose errors propagate.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return ``True`` unless the inspected model is marked V0-only.

        Negates the ``supports_v0_only`` flag of the ``_ModelInfo`` returned
        by :meth:`inspect_model_cls`, whose errors propagate.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return not model_info.supports_v0_only


# Global registry: every built-in architecture in `_VLLM_MODELS` is registered
# lazily as (module under `vllm.model_executor.models`, class name).
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

# Return type of the callable passed to `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run ``fn`` in a fresh Python subprocess and return its result.

    ``fn`` and the path of an output file are serialized with ``cloudpickle``
    (so lambdas are allowed) and piped on stdin to the process started from
    ``_SUBPROCESS_COMMAND``. That process is expected to write the pickled
    result of ``fn()`` to the output file, as this module's ``_run`` does.
    The file is then unpickled and returned.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status. The
            message includes the subprocess's stderr, and the original
            ``CalledProcessError`` is chained as the cause.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently: a strict decode could raise UnicodeDecodeError
            # on non-UTF-8 stderr and hide the real failure.
            stderr_text = returned.stderr.decode(errors="replace")
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr_text}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Script entry point for the subprocess side of the registry helpers.

    Loads general plugins, reads a pickled ``(fn, output_file)`` pair from
    stdin, calls ``fn()`` and pickles the result into ``output_file``.
    Presumably launched through ``_SUBPROCESS_COMMAND`` by
    ``_run_in_subprocess`` (the command is defined elsewhere).
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE(review): stdin is unpickled without checks; this assumes the
    # payload always comes from the parent `_run_in_subprocess` call.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)


if __name__ == "__main__":
    # Executed as a script; presumably how `_SUBPROCESS_COMMAND` runs this
    # module on behalf of `_run_in_subprocess`.
    _run()