registry.py 38.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import importlib
8
import os
9
import pickle
10
11
import subprocess
import sys
12
import tempfile
13
from abc import ABC, abstractmethod
14
from collections.abc import Set
15
from dataclasses import dataclass, field
16
from functools import lru_cache
17
from typing import Callable, Optional, TypeVar, Union
18
19

import torch.nn as nn
20
import transformers
21

22
23
from vllm.config import (ModelConfig, ModelImpl, iter_architecture_defaults,
                         try_match_architecture_defaults)
24
from vllm.logger import init_logger
25
26
from vllm.transformers_utils.dynamic_module import (
    try_get_class_from_dynamic_module)
27

28
29
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
30
31
                         supports_multimodal,
                         supports_multimodal_encoder_tp_data,
32
                         supports_multimodal_raw_input_only, supports_pp,
33
                         supports_transcription, supports_v0_only)
34
35
from .interfaces_base import (get_default_pooling_type, is_pooling_model,
                              is_text_generation_model)
36
37
38

logger = init_logger(__name__)

39
# yapf: disable
40
41
# Text-generation architectures. Each entry maps an architecture name to a
# `(module name, class name)` pair of strings. Insertion order is preserved
# on purpose: other tables in this file are derived from this one.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),  # noqa: E501
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "MotifForCausalLM": ("motif", "MotifForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
    "MBartForConditionalGeneration": ("bart", "MBartForConditionalGeneration"),
}

# Embedding / pooling architectures. Each entry maps an architecture name to
# a `(module name, class name)` pair of strings. Every text-generation entry
# that resolves to "LlamaForCausalLM" is spliced in as well (see the `**{...}`
# expansion below), so the order of entries here matters for key precedence.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically Terratorch models work on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

197
198
# Cross-encoder (sequence classification) architectures. Each entry maps an
# architecture name to a `(module name, class name)` pair of strings.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "GteNewForSequenceClassification": ("bert_with_rope",
                                        "GteNewForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501
}

211
# Multimodal architectures. Each entry maps an architecture name to a
# `(module name, class name)` pair of strings.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"),  # noqa: E501
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"),    # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL": ("nano_nemotron_vl", "NemotronH_Nano_VL"),
    "InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "InternVLForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": ("keye_vl1_5", "KeyeVL1_5ForConditionalGeneration"),  # noqa: E501
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "DonutForConditionalGeneration": ("donut", "DonutForConditionalGeneration"),
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
277
278

# Draft / speculative-decoding architectures. Each entry maps an architecture
# name to a `(module name, class name)` pair of strings.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
295

296
# Architectures mapped onto the generic classes in the `transformers` module
# entry (same `(module name, class name)` value shape as the tables above).
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}

# The generic Transformers-backend classes themselves, registered under their
# own names. `_try_resolve_transformers` returns an architecture unchanged
# when it is already a key of this table.
_TRANSFORMERS_BACKEND_MODELS = {
    "TransformersModel": ("transformers", "TransformersModel"),
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}
308
# yapf: enable
309

310
# Union of every per-category table above. When the same architecture name
# appears in more than one table, the entry from the table listed later here
# wins (standard dict-unpacking semantics).
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

320
321
322
323
324
325
326
327
# Argument list passed to subprocess.run(). It is kept as a module-level
# variable so the arguments can be altered when needed, e.g. when things
# are packed together in par format and sys.executable might not be the
# target we want to run.
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]

328
329
# Architectures that are no longer supported, mapped to the vLLM version
# quoted in the "was supported in vLLM until v..." error message.
_PREVIOUSLY_SUPPORTED_MODELS = {
    "Phi3SmallForCausalLM": "0.9.2",
}

330

331
332
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable record of the capabilities reported for one model class.

    Every field other than `architecture` is filled in by
    `from_model_cls` using the same-named predicate imported from
    `.interfaces` / `.interfaces_base`.
    """

    # `__name__` of the inspected model class.
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    default_pooling_type: str
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_transcription_only: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """
        Build a `_ModelInfo` by querying each capability predicate on `model`.

        Args:
            model: The model class (not an instance) to inspect.

        Returns:
            A frozen `_ModelInfo` describing `model`.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=
            supports_multimodal_raw_input_only(model),
            supports_multimodal_encoder_tp_data=
            supports_multimodal_encoder_tp_data(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            # The class attribute is only read once the model is known to
            # support transcription, so models without it are never touched.
            supports_transcription_only=(supports_transcription(model) and
                                         model.supports_transcription_only),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
373
374


375
class _BaseRegisteredModel(ABC):
    """
    Common interface of every entry stored in the model registry.
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """
        Return the capability summary (`_ModelInfo`) of the model class.
        """
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """
        Return the model class itself.
        """
        raise NotImplementedError
384
385


386
387
388
389
390
391
392
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability summary computed once, when the entry is created.
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """
        Wrap `model_cls`, inspecting its capabilities eagerly.
        """
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """
        Return the capability summary stored at construction time.
        """
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """
        Return the stored model class; nothing is imported here.
        """
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Name handed to `importlib.import_module` when the class is loaded.
    module_name: str
    # Attribute looked up on the imported module.
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """
        Import and inspect the model class via `_run_in_subprocess`.
        """
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """
        Import `module_name` and return its `class_name` attribute.

        Any `ImportError` / `AttributeError` propagates to the caller.
        """
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """
    Load the class behind `model`, returning `None` if loading fails.

    The platform check runs outside the `try`, so an error raised by
    `verify_model_arch` propagates to the caller. Failures of
    `load_model_cls` are logged with a traceback and turned into `None`.
    Results, including `None`, are memoized by `lru_cache` per
    `(model_arch, model)` pair.
    """
    # Imported here rather than at module level, as in the original code.
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
440
441


442
443
444
445
446
447
448
449
450
451
452
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """
    Inspect `model`, returning `None` if inspection raises.

    A failure is logged with its traceback. Results, including `None`,
    are memoized by `lru_cache` per `(model_arch, model)` pair.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        return None
    return info
453
454


455
456
457
@dataclass
class _ModelRegistry:
    # Keyed by model_arch
458
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
459

460
    def get_supported_archs(self) -> Set[str]:
        """
        Return the registered architecture names.

        The result is the live key view of `self.models`, so it reflects
        later registrations.
        """
        return self.models.keys()
462

463
464
465
    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that already exists logs a warning and
        overwrites the previous entry.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor a `torch.nn.Module` subclass.
            ValueError: If `model_cls` is a string that does not contain
                exactly one `:` separator.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`); the
            # message previously printed the type of `model_arch`, which is
            # always `str` at this point.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        self.models[model_arch] = model
504

505
    def _raise_for_unsupported(self, architectures: list[str]):
        """
        Raise a `ValueError` explaining why `architectures` cannot be used.

        This method never returns normally. The message depends on the
        first matching case:

        1. At least one architecture is registered: inspection must have
           failed, so the user is pointed to the logs.
        2. An architecture is listed in `_PREVIOUSLY_SUPPORTED_MODELS`:
           support was dropped after the recorded version.
        3. Otherwise: the architectures are unknown, and the supported
           ones are listed.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture.")

        # Format as a plain list so the message does not leak the
        # `dict_keys([...])` repr of the underlying key view.
        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {list(all_supported_archs)}")
526

527
    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """
        Load the class registered for `model_arch`.

        Returns `None` when the architecture is not registered; otherwise
        delegates to the cached module-level `_try_load_model_cls`.
        """
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])
533

534
    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """
        Inspect the class registered for `model_arch`.

        Returns `None` when the architecture is not registered; otherwise
        delegates to the cached module-level `_try_inspect_model_cls`.
        """
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> Optional[str]:
        """
        Resolve `architecture` to a Transformers-backend architecture name.

        Args:
            architecture: Architecture name taken from the model config.
            model_config: Supplies `hf_config.auto_map`, the model
                path/revision used for dynamic modules, and `model_impl`.

        Returns:
            `architecture` itself if it is already one of
            `_TRANSFORMERS_BACKEND_MODELS`; otherwise the value of
            `model_config._get_transformers_backend_cls()`; or `None` when
            the model cannot use the backend and the Transformers
            implementation was not explicitly requested.

        Raises:
            ValueError: If `model_impl` is `ModelImpl.TRANSFORMERS` but the
                model class cannot be found or is not backend-compatible.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
                                           None) or {}

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            # Not a built-in Transformers class: fall back to the first
            # `AutoModel*` entry of `auto_map` that can be loaded.
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # Reached only when the loop did not `break`, i.e. no
                # usable model class was found.
                if model_config.model_impl != ModelImpl.TRANSFORMERS:
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom).")

        if not model_module.is_backend_compatible():
            if model_config.model_impl != ModelImpl.TRANSFORMERS:
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM.")

        return model_config._get_transformers_backend_cls()
602

603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Map ``architecture`` onto a registered architecture name if possible.

        Args:
            architecture: The architecture name to normalize.
            model_config: Config whose ``runner_type`` / ``convert_type``
                attributes (``None`` when absent) guide the default match.

        Returns:
            ``architecture`` itself when it is registered in ``self.models``;
            otherwise the first registered name obtained by swapping the
            matched suffix for one of the default suffixes; otherwise
            ``architecture`` unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        matched = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if not matched:
            return architecture

        matched_suffix, _ = matched

        # Lazily derive candidate base-model names, one per default suffix,
        # and stop at the first one that is registered.
        candidates = (architecture.replace(matched_suffix, other_suffix)
                      for other_suffix, _ in iter_architecture_defaults())
        return next(
            (candidate
             for candidate in candidates if candidate in self.models),
            architecture,
        )
628

629
630
    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Resolve the model info for the first supported architecture.

        Resolution order:

        1. If ``model_config.model_impl`` is ``TRANSFORMERS``, try the
           Transformers backend for the first architecture; if it is
           ``TERRATORCH``, try the ``"Terratorch"`` entry.
        2. If ``model_impl`` is ``AUTO``, none of the architectures is
           registered and ``convert_type`` is ``"none"`` (or absent), try the
           Transformers backend.
        3. Try every architecture in order, after normalizing its name.
        4. If ``model_impl`` is ``AUTO`` and none of the architectures is
           registered, try the Transformers backend once more.

        Args:
            architectures: One architecture name or a list of them.
            model_config: The model configuration driving the resolution.

        Returns:
            A ``(model_info, arch)`` tuple. In step 3, ``arch`` is the
            original (non-normalized) architecture name.

        Raises:
            ValueError: If ``architectures`` is empty.

        When nothing can be resolved, the result of
        ``self._raise_for_unsupported(architectures)`` is returned
        (presumably that helper raises -- confirm against its definition).
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == ModelImpl.TERRATORCH:
            # Only return on success, mirroring `resolve_model_cls`; a failed
            # inspection must not leak `None` to callers that immediately
            # read attributes of the returned model info.
            arch = "Terratorch"
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == ModelImpl.AUTO
                and getattr(model_config, "convert_type", "none") == "none"):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == ModelImpl.AUTO):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)
679

680
681
    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Load the model class for the first supported architecture.

        Resolution order:

        1. If ``model_config.model_impl`` is ``TRANSFORMERS``, try the
           Transformers backend for the first architecture; if it is
           ``TERRATORCH``, try the ``"Terratorch"`` entry.
        2. If ``model_impl`` is ``AUTO``, none of the architectures is
           registered and ``convert_type`` is ``"none"`` (or absent), try the
           Transformers backend.
        3. Try every architecture in order, after normalizing its name.
        4. If ``model_impl`` is ``AUTO`` and none of the architectures is
           registered, try the Transformers backend once more.

        Args:
            architectures: One architecture name or a list of them.
            model_config: The model configuration driving the resolution.

        Returns:
            A ``(model_cls, arch)`` tuple. In step 3, ``arch`` is the
            original (non-normalized) architecture name.

        Raises:
            ValueError: If ``architectures`` is empty.

        When nothing can be resolved, the result of
        ``self._raise_for_unsupported(architectures)`` is returned
        (presumably that helper raises -- confirm against its definition).
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
        elif model_config.model_impl == ModelImpl.TERRATORCH:
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == ModelImpl.AUTO
                and getattr(model_config, "convert_type", "none") == "none"):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == ModelImpl.AUTO):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        return self._raise_for_unsupported(architectures)
732

733
734
    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``is_text_generation_model`` flag of the resolved model.

        The architectures are resolved via ``inspect_model_cls`` and the flag
        is read from the model info it returns.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model
740

741
    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``is_pooling_model`` flag of the resolved model.

        The architectures are resolved via ``inspect_model_cls`` and the flag
        is read from the model info it returns.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model
748

749
750
    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``supports_cross_encoding`` flag of the resolved model.

        The architectures are resolved via ``inspect_model_cls`` and the flag
        is read from the model info it returns.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding
756

757
758
    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``supports_multimodal`` flag of the resolved model.

        The architectures are resolved via ``inspect_model_cls`` and the flag
        is read from the model info it returns.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal
764

765
    def is_multimodal_raw_input_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``supports_multimodal_raw_input_only`` flag of the model.

        The architectures are resolved via ``inspect_model_cls`` and the flag
        is read from the model info it returns.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only
772

773
774
    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``supports_pp`` flag of the resolved model.

        The architectures are resolved via ``inspect_model_cls`` and the flag
        is read from the model info it returns.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp
780

781
782
    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``has_inner_state`` flag of the resolved model.

        The architectures are resolved via ``inspect_model_cls`` and the flag
        is read from the model info it returns.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state
788

789
790
    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``is_attention_free`` flag of the resolved model.

        The architectures are resolved via ``inspect_model_cls`` and the flag
        is read from the model info it returns.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free
796

797
798
    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``is_hybrid`` flag of the resolved model.

        The architectures are resolved via ``inspect_model_cls`` and the flag
        is read from the model info it returns.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

805
806
    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``has_noops`` flag of the resolved model.

        The architectures are resolved via ``inspect_model_cls`` and the flag
        is read from the model info it returns.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

813
814
    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``supports_transcription`` flag of the resolved model.

        The architectures are resolved via ``inspect_model_cls`` and the flag
        is read from the model info it returns.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

821
822
823
    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``supports_transcription_only`` flag of the model.

        The architectures are resolved via ``inspect_model_cls`` and the flag
        is read from the model info it returns.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

829
830
    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the negation of the resolved model's ``supports_v0_only``.

        The architectures are resolved via ``inspect_model_cls`` and the flag
        is read from the model info it returns.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return not model_info.supports_v0_only

837
838

# Global registry: one lazily-registered entry per architecture listed in
# `_VLLM_MODELS`, pointing at the class `cls_name` inside the module
# `vllm.model_executor.models.<mod_relname>`.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

# Return type of the callable passed to `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run ``fn`` in a subprocess and return its result.

    ``fn`` and an output file path are serialized with ``cloudpickle`` and
    fed to the standard input of the process started with
    ``_SUBPROCESS_COMMAND``. After the process exits successfully, the result
    is unpickled from the output file (the subprocess is expected to write it
    there; see ``_run``).

    Args:
        fn: A zero-argument callable to execute in the subprocess.

    Returns:
        The object unpickled from the subprocess's output file.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero return code;
            the message carries the subprocess's stderr output.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # wrap raised exception to provide more information;
            # decode leniently so that non-UTF-8 bytes in stderr cannot
            # replace the real failure with a UnicodeDecodeError
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError("Error raised in subprocess:\n"
                               f"{stderr_text}") from e

        # The file is produced by the subprocess we just launched, inside a
        # temporary directory created above, so it is unpickled as-is.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Subprocess entry point: run a pickled callable and store its result.

    Reads a pickled ``(fn, output_file)`` pair from standard input, calls
    ``fn()`` and writes the pickled result to ``output_file``.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: unpickling can execute arbitrary code, so standard input must
    # only ever be fed by a trusted parent process.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        f.write(pickle.dumps(result))
889
890
891


# Executing this module as a script runs the subprocess entry point.
if __name__ == "__main__":
    _run()