registry.py 36.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import importlib
8
import os
9
import pickle
10
11
import subprocess
import sys
12
import tempfile
13
from abc import ABC, abstractmethod
14
from collections.abc import Set
15
from dataclasses import dataclass, field
16
from functools import lru_cache
17
from typing import Callable, Optional, TypeVar, Union
18
19

import torch.nn as nn
20
import transformers
21

22
23
from vllm.config import (ModelConfig, ModelImpl, iter_architecture_defaults,
                         try_match_architecture_defaults)
24
from vllm.logger import init_logger
25
26
from vllm.transformers_utils.dynamic_module import (
    try_get_class_from_dynamic_module)
27

28
29
from .interfaces import (get_default_pooling_type, has_inner_state, has_noops,
                         is_attention_free, is_hybrid, supports_cross_encoding,
30
31
                         supports_multimodal, supports_multimodal_raw_input,
                         supports_pp, supports_transcription, supports_v0_only)
32
from .interfaces_base import is_pooling_model, is_text_generation_model
33
34
35

# Module-level logger, created by `init_logger` with this module's name.
logger = init_logger(__name__)

36
# yapf: disable
37
38
# Maps an architecture name to a `(module, class)` pair of strings for
# text-generation models.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),  # noqa: E501
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

# Maps an architecture name to a `(module, class)` pair of strings for
# embedding models.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

186
187
188
189
190
191
# Maps an architecture name to a `(module, class)` pair of strings for
# cross-encoder models.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
}

198
# Maps an architecture name to a `(module, class)` pair of strings for
# multimodal models.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"),    # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
256
257

# Maps an architecture name to a `(module, class)` pair of strings for
# speculative-decoding models.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611  # noqa: E501
    # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
272

273
# Architectures mapped onto classes in the `transformers` module of this
# package.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}

# Backend class names; `_try_resolve_transformers` returns an architecture
# unchanged when it is one of these keys.
_TRANSFORMERS_BACKEND_MODELS = {
    "TransformersModel": ("transformers", "TransformersModel"),
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}
285
# yapf: enable
286

287
# Union of all per-category tables. Where a key appears in more than one
# table, the entry from the table unpacked later wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

297
298
299
300
301
302
303
304
# This variable is used as the args for subprocess.run(). We
# can modify this variable to alter the args if needed, e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
_SUBPROCESS_COMMAND = [
    sys.executable, "-m", "vllm.model_executor.models.registry"
]

305
306
# Architectures that are no longer supported, mapped to the version string
# that `_raise_for_unsupported` reports as the last one supporting them.
_PREVIOUSLY_SUPPORTED_MODELS = {"Phi3SmallForCausalLM": "0.9.2"}

307

308
309
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable record of the interface flags computed for one model class.

    Instances are normally built with `from_model_cls`, which fills each
    field from the interface helper of the same name.
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    default_pooling_type: str
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_transcription_only: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """
        Build a `_ModelInfo` by querying the interface helpers on `model`.

        `architecture` is taken from the class's `__name__`.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input=supports_multimodal_raw_input(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            # The class attribute is only read when the model supports
            # transcription at all.
            supports_transcription_only=(supports_transcription(model) and
                                         model.supports_transcription_only),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
346
347


348
class _BaseRegisteredModel(ABC):
    """Abstract interface for an entry stored in the model registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing this entry's model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class this entry refers to."""
        raise NotImplementedError
357
358


359
360
361
362
363
364
365
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap `model_cls`, computing its `_ModelInfo` immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` stored at construction time."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    module_name: str
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """Compute the `_ModelInfo` through `_run_in_subprocess`."""
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """
    Load the model class for `model_arch`, returning `None` on failure.

    The platform check runs outside the `try`, so anything it raises
    propagates to the caller. Exceptions raised while loading are logged
    and turned into `None`. Results are memoized by `lru_cache`, which
    includes a `None` result.
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
413
414


415
416
417
418
419
420
421
422
423
424
425
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """
    Inspect `model`, returning its `_ModelInfo` or `None` on failure.

    Any exception raised by the inspection is logged together with
    `model_arch` and turned into `None`. Results are memoized by
    `lru_cache`.
    """
    info: Optional[_ModelInfo] = None
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
    return info
426
427


428
429
430
@dataclass
class _ModelRegistry:
    # Keyed by model_arch
431
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
432

433
    def get_supported_archs(self) -> Set[str]:
434
        return self.models.keys()
435

436
437
438
    def register_model(
        self,
        model_arch: str,
439
        model_cls: Union[type[nn.Module], str],
440
    ) -> None:
441
442
443
        """
        Register an external model to be used in vLLM.

444
        `model_cls` can be either:
445

446
        - A [`torch.nn.Module`][] class directly referencing the model.
447
        - A string in the format `<module>:<class>` which can be used to
448
449
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
450
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
451
        """
452
453
454
455
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

456
        if model_arch in self.models:
457
458
459
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
460
461
462
463
464
465
466
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)
467

468
            model = _LazyRegisteredModel(*split_str)
469
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
470
            model = _RegisteredModel.from_model_cls(model_cls)
471
472
473
474
        else:
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_arch)}")
            raise TypeError(msg)
475

476
        self.models[model_arch] = model
477

478
    def _raise_for_unsupported(self, architectures: list[str]):
479
        all_supported_archs = self.get_supported_archs()
480

481
482
483
484
485
        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

486
487
488
489
490
491
492
493
494
495
        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture.")

496
497
498
        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")
499

500
    def _try_load_model_cls(self,
501
                            model_arch: str) -> Optional[type[nn.Module]]:
502
503
        if model_arch not in self.models:
            return None
504

505
        return _try_load_model_cls(model_arch, self.models[model_arch])
506

507
    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
508
509
        if model_arch not in self.models:
            return None
510

511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> Optional[str]:
        """
        Resolve `architecture` to a Transformers backend class name.

        Returns:
            `architecture` itself if it is a key of
            `_TRANSFORMERS_BACKEND_MODELS`; otherwise the value of
            `model_config._get_transformers_backend_cls()` once a compatible
            model module has been found; or `None` when no compatible module
            is found and `model_config.model_impl` is not
            `ModelImpl.TRANSFORMERS`.

        Raises:
            ValueError: If `model_config.model_impl` is
                `ModelImpl.TRANSFORMERS` and either no model module can be
                found or the one found is not backend compatible.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
                                           None) or {}

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # Reached only when the loop finished without `break`, i.e.
                # no "AutoModel*" entry produced a class.
                if model_config.model_impl != ModelImpl.TRANSFORMERS:
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom).")

        if not model_module.is_backend_compatible():
            if model_config.model_impl != ModelImpl.TRANSFORMERS:
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM.")

        return model_config._get_transformers_backend_cls()
575

576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture
601

602
603
    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return the model info and architecture name for the first match.

        `architectures` may be a single name or a list of names. The
        Transformers implementation is only ever attempted for the first
        architecture. For registry lookups the returned name is the entry of
        `architectures` as given, not its normalized form.

        Raises:
            ValueError: If `architectures` is empty.

        When nothing can be inspected, the result of
        `self._raise_for_unsupported(architectures)` is returned.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        def inspect_via_transformers() -> Optional[tuple[_ModelInfo, str]]:
            # Resolve the first architecture through the Transformers backend
            # and inspect whatever architecture name that yields.
            backend_arch = self._try_resolve_transformers(
                architectures[0], model_config)
            if backend_arch is None:
                return None
            backend_info = self._try_inspect_model_cls(backend_arch)
            if backend_info is None:
                return None
            return (backend_info, backend_arch)

        def none_registered() -> bool:
            # True when no requested architecture is a key of `self.models`.
            return not any(name in self.models for name in architectures)

        # Require transformers impl
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            found = inspect_via_transformers()
            if found is not None:
                return found

        # Fallback to transformers impl (after resolving convert_type)
        if (none_registered()
                and model_config.model_impl == ModelImpl.AUTO
                and getattr(model_config, "convert_type", "none") == "none"):
            found = inspect_via_transformers()
            if found is not None:
                return found

        for candidate in architectures:
            registered_name = self._normalize_arch(candidate, model_config)
            info = self._try_inspect_model_cls(registered_name)
            if info is not None:
                return (info, candidate)

        # Fallback to transformers impl (before resolving runner_type)
        if (none_registered()
                and model_config.model_impl == ModelImpl.AUTO):
            found = inspect_via_transformers()
            if found is not None:
                return found

        return self._raise_for_unsupported(architectures)
649

650
651
    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return the model class and architecture name for the first match.

        `architectures` may be a single name or a list of names. The
        Transformers implementation is only ever attempted for the first
        architecture. For registry lookups the returned name is the entry of
        `architectures` as given, not its normalized form.

        Raises:
            ValueError: If `architectures` is empty.

        When nothing can be loaded, the result of
        `self._raise_for_unsupported(architectures)` is returned.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        def load_via_transformers() -> Optional[tuple[type[nn.Module], str]]:
            # Resolve the first architecture through the Transformers backend
            # and load whatever architecture name that yields.
            backend_arch = self._try_resolve_transformers(
                architectures[0], model_config)
            if backend_arch is None:
                return None
            backend_cls = self._try_load_model_cls(backend_arch)
            if backend_cls is None:
                return None
            return (backend_cls, backend_arch)

        def none_registered() -> bool:
            # True when no requested architecture is a key of `self.models`.
            return not any(name in self.models for name in architectures)

        # Require transformers impl
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            found = load_via_transformers()
            if found is not None:
                return found

        # Fallback to transformers impl (after resolving convert_type)
        if (none_registered()
                and model_config.model_impl == ModelImpl.AUTO
                and getattr(model_config, "convert_type", "none") == "none"):
            found = load_via_transformers()
            if found is not None:
                return found

        for candidate in architectures:
            registered_name = self._normalize_arch(candidate, model_config)
            loaded_cls = self._try_load_model_cls(registered_name)
            if loaded_cls is not None:
                return (loaded_cls, candidate)

        # Fallback to transformers impl (before resolving runner_type)
        if (none_registered()
                and model_config.model_impl == ModelImpl.AUTO):
            found = load_via_transformers()
            if found is not None:
                return found

        return self._raise_for_unsupported(architectures)
697

698
699
    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `is_text_generation_model` of the inspected model info.

        The model info is obtained from `inspect_model_cls`.
        """
        model_info, _matched_arch = self.inspect_model_cls(
            architectures, model_config)
        return model_info.is_text_generation_model
705

706
    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `is_pooling_model` of the inspected model info.

        The model info is obtained from `inspect_model_cls`.
        """
        model_info, _matched_arch = self.inspect_model_cls(
            architectures, model_config)
        return model_info.is_pooling_model
713

714
715
    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `supports_cross_encoding` of the inspected model info.

        The model info is obtained from `inspect_model_cls`.
        """
        model_info, _matched_arch = self.inspect_model_cls(
            architectures, model_config)
        return model_info.supports_cross_encoding
721

722
723
    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `supports_multimodal` of the inspected model info.

        The model info is obtained from `inspect_model_cls`.
        """
        model_info, _matched_arch = self.inspect_model_cls(
            architectures, model_config)
        return model_info.supports_multimodal
729

730
731
732
    def supports_multimodal_raw_input(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `supports_multimodal_raw_input` of the inspected model info.

        The model info is obtained from `inspect_model_cls`.
        """
        model_info, _matched_arch = self.inspect_model_cls(
            architectures, model_config)
        return model_info.supports_multimodal_raw_input

738
739
    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `supports_pp` of the inspected model info.

        The model info is obtained from `inspect_model_cls`.
        """
        model_info, _matched_arch = self.inspect_model_cls(
            architectures, model_config)
        return model_info.supports_pp
745

746
747
    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `has_inner_state` of the inspected model info.

        The model info is obtained from `inspect_model_cls`.
        """
        model_info, _matched_arch = self.inspect_model_cls(
            architectures, model_config)
        return model_info.has_inner_state
753

754
755
    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `is_attention_free` of the inspected model info.

        The model info is obtained from `inspect_model_cls`.
        """
        model_info, _matched_arch = self.inspect_model_cls(
            architectures, model_config)
        return model_info.is_attention_free
761

762
763
    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `is_hybrid` of the inspected model info.

        The model info is obtained from `inspect_model_cls`.
        """
        model_info, _matched_arch = self.inspect_model_cls(
            architectures, model_config)
        return model_info.is_hybrid

770
771
    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `has_noops` of the inspected model info.

        The model info is obtained from `inspect_model_cls`.
        """
        model_info, _matched_arch = self.inspect_model_cls(
            architectures, model_config)
        return model_info.has_noops

778
779
    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `supports_transcription` of the inspected model info.

        The model info is obtained from `inspect_model_cls`.
        """
        model_info, _matched_arch = self.inspect_model_cls(
            architectures, model_config)
        return model_info.supports_transcription

786
787
788
    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return `supports_transcription_only` of the inspected model info.

        The model info is obtained from `inspect_model_cls`.
        """
        model_info, _matched_arch = self.inspect_model_cls(
            architectures, model_config)
        return model_info.supports_transcription_only

794
795
    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the negation of `supports_v0_only` of the inspected model.

        The model info is obtained from `inspect_model_cls`.
        """
        model_info, _matched_arch = self.inspect_model_cls(
            architectures, model_config)
        v0_only = model_info.supports_v0_only
        return not v0_only

802
803

# Registry with one lazily registered entry per architecture in
# `_VLLM_MODELS`; each entry records the module path and class name.
ModelRegistry = _ModelRegistry({
    arch_name: _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{module_stem}",
        class_name=model_cls_name,
    )
    for arch_name, (module_stem, model_cls_name) in _VLLM_MODELS.items()
})

# Return type of the callable passed to `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a subprocess and return its unpickled result.

    `fn` and an output file path are serialized with `cloudpickle` and fed to
    `_SUBPROCESS_COMMAND` on stdin. The result is then loaded with `pickle`
    from that output file.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero return code;
            the message carries the subprocess's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently so that undecodable stderr bytes cannot raise
            # a UnicodeDecodeError here and hide the actual failure.
            stderr_text = returned.stderr.decode(errors="replace")
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr_text}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Execute a pickled callable read from stdin and store its result.

    Standard input must hold a pickled `(fn, output_file)` pair. `fn` is
    called without arguments and its pickled return value is written to
    `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    payload = sys.stdin.buffer.read()
    task, result_path = pickle.loads(payload)

    outcome = task()

    with open(result_path, "wb") as sink:
        sink.write(pickle.dumps(outcome))
854
855
856


# Script entry point: run `_run` when this module is executed directly.
if __name__ == "__main__":
    _run()