registry.py 38 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import importlib
8
import os
9
import pickle
10
11
import subprocess
import sys
12
import tempfile
13
from abc import ABC, abstractmethod
14
from collections.abc import Set
15
from dataclasses import dataclass, field
16
from functools import lru_cache
17
from typing import Callable, Optional, TypeVar, Union
18
19

import torch.nn as nn
20
import transformers
21

22
23
from vllm.config import (ModelConfig, ModelImpl, iter_architecture_defaults,
                         try_match_architecture_defaults)
24
from vllm.logger import init_logger
25
26
from vllm.transformers_utils.dynamic_module import (
    try_get_class_from_dynamic_module)
27

28
29
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
30
31
                         supports_multimodal,
                         supports_multimodal_encoder_tp_data,
32
                         supports_multimodal_raw_input_only, supports_pp,
33
                         supports_transcription, supports_v0_only)
34
35
from .interfaces_base import (get_default_pooling_type, is_pooling_model,
                              is_text_generation_model)
36
37
38

logger = init_logger(__name__)

39
# yapf: disable
40
41
# Text-generation architectures.
# Each key is an architecture name and each value is a
# `(module name, class name)` pair naming the implementing class.
# NOTE(review): the module name is presumably resolved relative to this
# package by code outside this view - confirm against the consumer.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),  # noqa: E501
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
    "MBartForConditionalGeneration": ("bart", "MBartForConditionalGeneration"),
}

# Embedding / pooling architectures.
# Same layout as the other tables in this module: architecture name ->
# `(module name, class name)`.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. It is listed here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

193
194
# Cross-encoder (sequence classification / ranking) architectures.
# Architecture name -> `(module name, class name)`.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "GteNewForSequenceClassification": ("bert_with_rope",
                                        "GteNewForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
}

207
# Multimodal architectures.
# Architecture name -> `(module name, class name)`.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"),  # noqa: E501
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"),    # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "InternVLForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "DonutForConditionalGeneration": ("donut", "DonutForConditionalGeneration"),
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
270
271

# Draft / speculator architectures used for speculative decoding.
# Architecture name -> `(module name, class name)`.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611  # noqa: E501
    # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
288

289
# Architectures that are mapped onto the classes in the "transformers"
# module (see the class names on the right-hand side).
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}

# The Transformers-backend classes themselves, registered under their own
# names. `_try_resolve_transformers` returns an architecture unchanged when
# it is already one of these keys.
_TRANSFORMERS_BACKEND_MODELS = {
    "TransformersModel": ("transformers", "TransformersModel"),
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}
301
# yapf: enable
302

303
# Union of all per-category tables above. The tables are merged in the order
# listed, so when an architecture name appears in more than one table the
# entry from the later table wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

313
314
315
316
317
318
319
320
# Argument vector passed to subprocess.run(). It is kept in a module-level
# variable so that it can be altered when needed, e.g. when everything is
# packed together in par format and sys.executable might not be the target
# we want to run.
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]

321
322
# Architectures that are no longer supported, mapped to the vLLM version
# reported in the "was supported in vLLM until v..." error message.
_PREVIOUSLY_SUPPORTED_MODELS = {
    "Phi3SmallForCausalLM": "0.9.2",
}

323

324
325
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable summary of what a model class supports.

    Every field except `architecture` is filled in by `from_model_cls` from
    the corresponding interface predicate applied to the model class.
    """
    # Name of the model class (`model.__name__`).
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    default_pooling_type: str
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    # Only true when the model supports transcription AND the class
    # attribute `supports_transcription_only` is truthy.
    supports_transcription_only: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """
        Build a `_ModelInfo` by querying each interface predicate on `model`.

        Args:
            model: The model class (not an instance) to inspect.

        Returns:
            A frozen `_ModelInfo` describing `model`.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=(
                supports_multimodal_raw_input_only(model)),
            supports_multimodal_encoder_tp_data=(
                supports_multimodal_encoder_tp_data(model)),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            # The attribute is only read when the predicate holds, so models
            # without it are never asked for `supports_transcription_only`.
            supports_transcription_only=(supports_transcription(model) and
                                         model.supports_transcription_only),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
366
367


368
class _BaseRegisteredModel(ABC):
    """
    Abstract interface of a registry entry.

    An entry must be able to report the capabilities of its model class and
    to return the model class itself.
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing the model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class."""
        raise NotImplementedError
377
378


379
380
381
382
383
384
385
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capabilities computed from `model_cls` when the entry is created.
    interfaces: _ModelInfo
    # The imported model class itself.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Create an entry for `model_cls`, inspecting it immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capabilities stored at construction time."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Name of the module to import, as passed to `importlib.import_module`.
    module_name: str
    # Name of the attribute to read from that module.
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """Import and inspect the model class via `_run_in_subprocess`."""
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """
    Load the model class of a registry entry, returning None on failure.

    Results are memoized per `(model_arch, model)` pair by `lru_cache`.
    The platform check runs outside the `try`, so an exception raised by
    `verify_model_arch` propagates to the caller instead of being logged.

    Args:
        model_arch: Architecture name, used for the platform check and in
            the error log.
        model: Registry entry whose class should be loaded.

    Returns:
        The model class, or None if loading raised an exception (the
        exception is logged).
    """
    # Imported lazily inside the function rather than at module import time.
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
433
434


435
436
437
438
439
440
441
442
443
444
445
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """
    Inspect a registry entry, returning None if inspection fails.

    Results are memoized per `(model_arch, model)` pair by `lru_cache`.
    Any exception raised by the entry is logged with its traceback and
    swallowed.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        info = None
    return info
446
447


448
449
450
@dataclass
class _ModelRegistry:
    # Keyed by model_arch
451
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
452

453
    def get_supported_archs(self) -> Set[str]:
454
        return self.models.keys()
455

456
457
458
    def register_model(
        self,
        model_arch: str,
459
        model_cls: Union[type[nn.Module], str],
460
    ) -> None:
461
462
463
        """
        Register an external model to be used in vLLM.

464
        `model_cls` can be either:
465

466
        - A [`torch.nn.Module`][] class directly referencing the model.
467
        - A string in the format `<module>:<class>` which can be used to
468
469
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
470
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
471
        """
472
473
474
475
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

476
        if model_arch in self.models:
477
478
479
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
480
481
482
483
484
485
486
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)
487

488
            model = _LazyRegisteredModel(*split_str)
489
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
490
            model = _RegisteredModel.from_model_cls(model_cls)
491
492
493
494
        else:
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_arch)}")
            raise TypeError(msg)
495

496
        self.models[model_arch] = model
497

498
    def _raise_for_unsupported(self, architectures: list[str]):
499
        all_supported_archs = self.get_supported_archs()
500

501
502
503
504
505
        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

506
507
508
509
510
511
512
513
514
515
        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture.")

516
517
518
        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")
519

520
    def _try_load_model_cls(self,
521
                            model_arch: str) -> Optional[type[nn.Module]]:
522
523
        if model_arch not in self.models:
            return None
524

525
        return _try_load_model_cls(model_arch, self.models[model_arch])
526

527
    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
528
529
        if model_arch not in self.models:
            return None
530

531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> Optional[str]:
        """
        Resolve `architecture` to a Transformers-backend architecture name.

        Args:
            architecture: Architecture name from the model's config.
            model_config: Provides `hf_config`, `model`, `revision` and
                `model_impl`.

        Returns:
            `architecture` itself if it is already a key of
            `_TRANSFORMERS_BACKEND_MODELS`; otherwise the value of
            `model_config._get_transformers_backend_cls()` when a compatible
            model class is found. Returns None when resolution fails and
            `model_impl` is not `ModelImpl.TRANSFORMERS`.

        Raises:
            ValueError: If `model_impl` is `ModelImpl.TRANSFORMERS` and
                either no model class can be found or the class found is
                not backend compatible.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
                                           None) or {}

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # Reached only when the loop finished without `break`, i.e.
                # no "AutoModel*" entry produced a usable class.
                if model_config.model_impl != ModelImpl.TRANSFORMERS:
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom).")

        if not model_module.is_backend_compatible():
            if model_config.model_impl != ModelImpl.TRANSFORMERS:
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM.")

        return model_config._get_transformers_backend_cls()
595

596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """
        Map `architecture` to the name of a registered architecture.

        A registered name is returned unchanged. Otherwise, if
        `try_match_architecture_defaults` yields a match, the matched text
        is swapped for each candidate from `iter_architecture_defaults` in
        turn and the first registered result is returned. If nothing is
        found, `architecture` is returned as given.
        """
        registered = self.models
        if architecture in registered:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        matched = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if not matched:
            return architecture

        matched_suffix, _ = matched

        # Get the name of the base model to convert
        # NOTE(review): `str.replace` substitutes every occurrence of the
        # matched text, not only a trailing one - confirm this is intended.
        candidates = (architecture.replace(matched_suffix, candidate_suffix)
                      for candidate_suffix, _ in iter_architecture_defaults())
        return next(
            (name for name in candidates if name in registered),
            architecture,
        )
621

622
623
    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return inspection info for the first usable architecture.

        Lookup order:

        1. If ``model_config.model_impl`` is ``ModelImpl.TRANSFORMERS``, try
           the Transformers implementation of the first architecture.
        2. If no architecture is registered, the impl is ``ModelImpl.AUTO``
           and ``convert_type`` is ``"none"`` (or absent), try the
           Transformers implementation.
        3. Try every architecture, normalized via ``_normalize_arch``.
        4. If no architecture is registered and the impl is
           ``ModelImpl.AUTO``, try the Transformers implementation again.

        Args:
            architectures: One architecture name or a list of them.
            model_config: Model configuration driving the resolution.

        Returns:
            ``(model_info, arch)``. For step 3, ``arch`` is the name as
            given by the caller, not the normalized one.

        Raises:
            ValueError: If ``architectures`` is empty.

        When every step fails, the result of ``_raise_for_unsupported`` is
        returned.
        """
        if isinstance(architectures, str):
            architectures = [architectures]

        if not architectures:
            raise ValueError("No model architectures are specified")

        def inspect_via_transformers() -> Optional[tuple[_ModelInfo, str]]:
            # Resolve the first architecture through Transformers, then
            # inspect whatever architecture name that resolution produced.
            resolved = self._try_resolve_transformers(architectures[0],
                                                      model_config)
            if resolved is None:
                return None
            info = self._try_inspect_model_cls(resolved)
            if info is None:
                return None
            return (info, resolved)

        def nothing_registered() -> bool:
            return all(name not in self.models for name in architectures)

        # Require transformers impl
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            found = inspect_via_transformers()
            if found is not None:
                return found

        # Fallback to transformers impl (after resolving convert_type)
        if (nothing_registered()
                and model_config.model_impl == ModelImpl.AUTO
                and getattr(model_config, "convert_type", "none") == "none"):
            found = inspect_via_transformers()
            if found is not None:
                return found

        for arch in architectures:
            candidate = self._normalize_arch(arch, model_config)
            info = self._try_inspect_model_cls(candidate)
            if info is not None:
                return (info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (nothing_registered()
                and model_config.model_impl == ModelImpl.AUTO):
            found = inspect_via_transformers()
            if found is not None:
                return found

        return self._raise_for_unsupported(architectures)
669

670
671
    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return the model class for the first usable architecture.

        Lookup order:

        1. If ``model_config.model_impl`` is ``ModelImpl.TRANSFORMERS``, try
           the Transformers implementation of the first architecture.
        2. If no architecture is registered, the impl is ``ModelImpl.AUTO``
           and ``convert_type`` is ``"none"`` (or absent), try the
           Transformers implementation.
        3. Try every architecture, normalized via ``_normalize_arch``.
        4. If no architecture is registered and the impl is
           ``ModelImpl.AUTO``, try the Transformers implementation again.

        Args:
            architectures: One architecture name or a list of them.
            model_config: Model configuration driving the resolution.

        Returns:
            ``(model_cls, arch)``. For step 3, ``arch`` is the name as
            given by the caller, not the normalized one.

        Raises:
            ValueError: If ``architectures`` is empty.

        When every step fails, the result of ``_raise_for_unsupported`` is
        returned.
        """
        if isinstance(architectures, str):
            architectures = [architectures]

        if not architectures:
            raise ValueError("No model architectures are specified")

        def load_via_transformers() -> Optional[tuple[type[nn.Module], str]]:
            # Resolve the first architecture through Transformers, then load
            # whatever architecture name that resolution produced.
            resolved = self._try_resolve_transformers(architectures[0],
                                                      model_config)
            if resolved is None:
                return None
            loaded = self._try_load_model_cls(resolved)
            if loaded is None:
                return None
            return (loaded, resolved)

        def nothing_registered() -> bool:
            return all(name not in self.models for name in architectures)

        # Require transformers impl
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            found = load_via_transformers()
            if found is not None:
                return found

        # Fallback to transformers impl (after resolving convert_type)
        if (nothing_registered()
                and model_config.model_impl == ModelImpl.AUTO
                and getattr(model_config, "convert_type", "none") == "none"):
            found = load_via_transformers()
            if found is not None:
                return found

        for arch in architectures:
            candidate = self._normalize_arch(arch, model_config)
            loaded = self._try_load_model_cls(candidate)
            if loaded is not None:
                return (loaded, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (nothing_registered()
                and model_config.model_impl == ModelImpl.AUTO):
            found = load_via_transformers()
            if found is not None:
                return found

        return self._raise_for_unsupported(architectures)
717

718
719
    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``is_text_generation_model`` flag of the model info
        that ``inspect_model_cls`` finds for ``architectures``."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model
725

726
    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``is_pooling_model`` flag of the model info that
        ``inspect_model_cls`` finds for ``architectures``."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model
733

734
735
    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``supports_cross_encoding`` flag of the model info
        that ``inspect_model_cls`` finds for ``architectures``."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding
741

742
743
    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``supports_multimodal`` flag of the model info that
        ``inspect_model_cls`` finds for ``architectures``."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal
749

750
    def is_multimodal_raw_input_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``supports_multimodal_raw_input_only`` flag of the
        model info that ``inspect_model_cls`` finds for ``architectures``."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input_only
757

758
759
    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``supports_pp`` flag of the model info that
        ``inspect_model_cls`` finds for ``architectures``."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp
765

766
767
    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``has_inner_state`` flag of the model info that
        ``inspect_model_cls`` finds for ``architectures``."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state
773

774
775
    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``is_attention_free`` flag of the model info that
        ``inspect_model_cls`` finds for ``architectures``."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free
781

782
783
    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``is_hybrid`` flag of the model info that
        ``inspect_model_cls`` finds for ``architectures``."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

790
791
    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``has_noops`` flag of the model info that
        ``inspect_model_cls`` finds for ``architectures``."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

798
799
    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``supports_transcription`` flag of the model info that
        ``inspect_model_cls`` finds for ``architectures``."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

806
807
808
    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the ``supports_transcription_only`` flag of the model info
        that ``inspect_model_cls`` finds for ``architectures``."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

814
815
    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return ``True`` unless the model info that ``inspect_model_cls``
        finds for ``architectures`` has ``supports_v0_only`` set."""
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        v0_only = model_info.supports_v0_only
        return not v0_only

822
823

# Module-level registry: each architecture name in `_VLLM_MODELS` maps to a
# `_LazyRegisteredModel` holding the dotted module path under
# `vllm.model_executor.models` and the class name to look up there.
ModelRegistry = _ModelRegistry({
    arch_name:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{module_stem}",
        class_name=class_name,
    )
    for arch_name, (module_stem, class_name) in _VLLM_MODELS.items()
})

# Return type of the callable passed to `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run ``fn`` in a child process and return its unpickled result.

    ``fn`` and an output path are serialized with ``cloudpickle`` and sent
    to the command in ``_SUBPROCESS_COMMAND`` over stdin. The child is
    expected to write the pickled return value to that path (presumably via
    ``_run`` in this module), and the file is read back once the child
    exits successfully.

    Args:
        fn: Zero-argument callable to execute in the child process.

    Returns:
        Whatever the child pickled into the output file.

    Raises:
        RuntimeError: If the child exits with a non-zero return code. The
            message carries the child's stderr and the original
            ``CalledProcessError`` is chained as the cause.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently: stderr is arbitrary bytes, and a strict
            # decode failure here would hide the real subprocess error.
            stderr_text = returned.stderr.decode(errors="replace")
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr_text}") from e

        # The file is produced by the child process started above, so it is
        # treated as trusted pickle data.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Execute a pickled task read from stdin and store its pickled result.

    Stdin must hold a pickled ``(callable, output_path)`` pair. General
    plugins are loaded first, then the callable is invoked with no
    arguments and its return value is pickled into ``output_path``.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE(review): the payload is unpickled without validation, so stdin
    # must come from a trusted sender.
    payload = sys.stdin.buffer.read()
    task, result_path = pickle.loads(payload)

    outcome = task()

    with open(result_path, "wb") as sink:
        sink.write(pickle.dumps(outcome))
874
875
876


# Script entry point: call `_run()` only when this module is executed
# directly rather than imported.
if __name__ == "__main__":
    _run()