registry.py 36.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import importlib
8
import os
9
import pickle
10
11
import subprocess
import sys
12
import tempfile
13
from abc import ABC, abstractmethod
14
from collections.abc import Set
15
from dataclasses import dataclass, field
16
from functools import lru_cache
17
from typing import Callable, Optional, TypeVar, Union
18
19

import torch.nn as nn
20
import transformers
21

22
23
from vllm.config import (ModelConfig, ModelImpl, iter_architecture_defaults,
                         try_match_architecture_defaults)
24
from vllm.logger import init_logger
25
26
from vllm.transformers_utils.dynamic_module import (
    try_get_class_from_dynamic_module)
27

28
29
from .interfaces import (get_default_pooling_type, has_inner_state, has_noops,
                         is_attention_free, is_hybrid, supports_cross_encoding,
30
31
                         supports_multimodal, supports_multimodal_raw_input,
                         supports_pp, supports_transcription, supports_v0_only)
32
from .interfaces_base import is_pooling_model, is_text_generation_model
33
34
35

# Module-level logger, created through `init_logger` with this module's name.
logger = init_logger(__name__)

36
# yapf: disable
37
38
# Maps an architecture name to a `(module name, class name)` pair for
# text-generation models.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),  # noqa: E501
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

# Maps an architecture name to a `(module name, class name)` pair for
# embedding models.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

186
187
188
189
190
191
# Maps an architecture name to a `(module name, class name)` pair for
# cross-encoder models.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501
}

198
# Maps an architecture name to a `(module name, class name)` pair for
# multimodal models.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"),    # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
255
256

# Maps an architecture name to a `(module name, class name)` pair for
# speculative-decoding models.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611  # noqa: E501
    # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
271

272
273
274
275
276
# Architectures that are mapped straight onto one of the wrapper classes in
# the `transformers` module of this package.
_TRANSFORMERS_SUPPORTED_MODELS = {
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersForMultimodalLM",
    ),
}

# The Transformers backend wrapper classes, registered under their own names.
_TRANSFORMERS_BACKEND_MODELS = {
    "TransformersModel": ("transformers", "TransformersModel"),
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}
281
# yapf: enable
282

283
# Union of all per-category tables above. When the same architecture name
# appears in more than one table, the entry from the table listed later wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

293
294
295
296
297
298
299
300
# Argument vector passed to subprocess.run(). It lives in a module-level
# variable so that it can be altered when needed, e.g. when everything is
# packed together in par format and sys.executable might not be the target
# we want to run.
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]

301
302
# Architectures that are no longer supported, mapped to the vLLM version
# until which they were supported.
_PREVIOUSLY_SUPPORTED_MODELS = {
    "Phi3SmallForCausalLM": "0.9.2",
}

303

304
305
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable record describing a model class.

    `architecture` is the class name; every other field holds the result of
    the identically named helper imported from `.interfaces` /
    `.interfaces_base` (see `from_model_cls`).
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    default_pooling_type: str
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_transcription_only: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """
        Build a `_ModelInfo` by querying each interface helper on `model`.

        Args:
            model: The model class to describe.

        Returns:
            A `_ModelInfo` whose `architecture` is `model.__name__`.
        """
        # `supports_transcription_only` is read from the class only when the
        # class supports transcription at all.
        transcription = supports_transcription(model)

        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input=supports_multimodal_raw_input(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=transcription,
            supports_transcription_only=(transcription and
                                         model.supports_transcription_only),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
342
343


344
class _BaseRegisteredModel(ABC):
    """Interface shared by the entries stored in the model registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
353
354


355
356
357
358
359
360
361
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability record computed once from `model_cls` (see `from_model_cls`).
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Wrap `model_cls`, inspecting its interfaces immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` stored at construction time."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Name of the module to import, as passed to `importlib.import_module`.
    module_name: str
    # Name of the attribute to fetch from that module.
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """Load and inspect the model class via `_run_in_subprocess`."""
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """
    Load the model class behind `model`, logging instead of raising on failure.

    Results are memoised by `lru_cache`, including a `None` result, so both
    arguments must be hashable.

    Args:
        model_arch: Architecture name; used for the platform check and in the
            log message.
        model: Registry entry whose `load_model_cls()` is called.

    Returns:
        The model class, or `None` if `load_model_cls()` raised.
    """
    from vllm.platforms import current_platform

    # Runs outside the try block, so an exception raised by the platform
    # check propagates to the caller.
    current_platform.verify_model_arch(model_arch)

    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
409
410


411
412
413
414
415
416
417
418
419
420
421
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """
    Inspect `model`, logging instead of raising when inspection fails.

    Returns whatever `model.inspect_model_cls()` returns, or `None` if that
    call raised. Results are memoised by `lru_cache`.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        info = None
    return info
422
423


424
425
426
@dataclass
class _ModelRegistry:
    # Keyed by model_arch
427
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
428

429
    def get_supported_archs(self) -> Set[str]:
        """Return the registered architecture names (the keys of `models`)."""
        return self.models.keys()
431

432
433
434
    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        An existing registration under `model_arch` is overwritten after a
        warning is logged.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that does not split into
                exactly two parts on `:`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`), not
            # `model_arch`, which is already known to be a string here.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        self.models[model_arch] = model
473

474
    def _raise_for_unsupported(self, architectures: list[str]):
        """
        Raise the `ValueError` that best explains why resolution failed.

        Never returns normally. The message depends on `architectures`:

        - at least one is registered: inspection failed, check the logs;
        - one is listed in `_PREVIOUSLY_SUPPORTED_MODELS`: support was dropped
          after the recorded version;
        - otherwise: unsupported, followed by the supported architectures.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")
495

496
    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """
        Load the class registered under `model_arch`.

        Returns `None` when the architecture is not registered; otherwise
        delegates to the module-level `_try_load_model_cls`.
        """
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])
502

503
    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """
        Inspect the class registered under `model_arch`.

        Returns `None` when the architecture is not registered; otherwise
        delegates to the module-level `_try_inspect_model_cls`.
        """
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> Optional[str]:
        """
        Resolve `architecture` to a Transformers backend architecture name.

        Args:
            architecture: Architecture name taken from the model's config.
            model_config: Supplies `hf_config` (for `auto_map`), `model`,
                `revision` and `model_impl`.

        Returns:
            `architecture` itself if it is already a key of
            `_TRANSFORMERS_BACKEND_MODELS`; otherwise the value of
            `model_config._get_transformers_backend_cls()`. `None` when no
            model class is found or it is not backend compatible and
            `model_impl` is not `ModelImpl.TRANSFORMERS`.

        Raises:
            ValueError: In those same failure cases when `model_impl` is
                `ModelImpl.TRANSFORMERS`.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
                                           None) or dict()

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # Reached only when the loop finished without `break`, i.e.
                # no "AutoModel*" entry produced a class.
                if model_config.model_impl != ModelImpl.TRANSFORMERS:
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom).")

        if not model_module.is_backend_compatible():
            if model_config.model_impl != ModelImpl.TRANSFORMERS:
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM.")

        return model_config._get_transformers_backend_cls()
571

572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """
        Map `architecture` onto a registered architecture name if possible.

        A registered name is returned as-is. Otherwise, if the name matches
        one of the architecture defaults, its matched suffix is swapped for
        each known suffix in turn and the first registered candidate is
        returned. If nothing matches, `architecture` is returned unchanged.
        """
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        runner_type = getattr(model_config, "runner_type", None)
        convert_type = getattr(model_config, "convert_type", None)
        matched = try_match_architecture_defaults(
            architecture,
            runner_type=runner_type,
            convert_type=convert_type,
        )
        if not matched:
            return architecture

        matched_suffix, _ = matched

        # Get the name of the base model to convert
        candidates = (architecture.replace(matched_suffix, other_suffix)
                      for other_suffix, _ in iter_architecture_defaults())
        return next((cand for cand in candidates if cand in self.models),
                    architecture)
597

598
599
    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """
        Find the first architecture that can be inspected.

        Args:
            architectures: One architecture name or a list of them.
            model_config: Supplies `model_impl` and, if present,
                `convert_type`; also forwarded to the resolution helpers.

        Returns:
            A `(model info, architecture name)` tuple.

        Raises:
            ValueError: If `architectures` is empty, or if no architecture
                could be inspected (raised by `_raise_for_unsupported`).
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        def _inspect_via_transformers() -> Optional[tuple[_ModelInfo, str]]:
            # Resolve the first architecture through the Transformers backend
            # and inspect the resolved name; None if either step fails.
            resolved = self._try_resolve_transformers(architectures[0],
                                                      model_config)
            if resolved is None:
                return None
            info = self._try_inspect_model_cls(resolved)
            return None if info is None else (info, resolved)

        # Require transformers impl
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            result = _inspect_via_transformers()
            if result is not None:
                return result

        can_fall_back = (all(arch not in self.models
                             for arch in architectures)
                         and model_config.model_impl == ModelImpl.AUTO)

        # Fallback to transformers impl (after resolving convert_type)
        if (can_fall_back
                and getattr(model_config, "convert_type", "none") == "none"):
            result = _inspect_via_transformers()
            if result is not None:
                return result

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                # The name returned is the one that was passed in, not the
                # normalized name that was inspected.
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if can_fall_back:
            result = _inspect_via_transformers()
            if result is not None:
                return result

        return self._raise_for_unsupported(architectures)
645

646
647
    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Find the first usable architecture and return its model class.

        Follows the same resolution order as `inspect_model_cls`, but loads
        the class via `_try_load_model_cls` instead of inspecting it.

        Args:
            architectures: A single architecture name or a list of candidate
                names, tried in order.
            model_config: Model configuration. Its `model_impl` (and
                `convert_type`, when present) decide if and when the
                Transformers implementation is tried.

        Returns:
            A `(model_cls, arch)` pair. For registered models, `arch` is the
            name as given in `architectures` (not the normalized name). For
            the Transformers implementation, it is the name returned by
            `_try_resolve_transformers`.

        Raises:
            ValueError: If `architectures` is empty.

        If no candidate can be resolved, the result of
        `_raise_for_unsupported(architectures)` is returned.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        # (only the first architecture is considered for this path)
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == ModelImpl.AUTO
                and getattr(model_config, "convert_type", "none") == "none"):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        # Registered models: load the normalized name but report the
        # architecture name as it was passed in
        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (all(arch not in self.models for arch in architectures)
                and model_config.model_impl == ModelImpl.AUTO):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        return self._raise_for_unsupported(architectures)
693

694
695
    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_text_generation_model` flag of the resolved model.

        The model info is obtained from `inspect_model_cls`, so any error it
        raises for the given architectures propagates to the caller.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_text_generation_model
701

702
    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_pooling_model` flag of the resolved model.

        The model info is obtained from `inspect_model_cls`, so any error it
        raises for the given architectures propagates to the caller.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_pooling_model
709

710
711
    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_cross_encoding` flag of the resolved model.

        The model info is obtained from `inspect_model_cls`, so any error it
        raises for the given architectures propagates to the caller.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_cross_encoding
717

718
719
    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_multimodal` flag of the resolved model.

        The model info is obtained from `inspect_model_cls`, so any error it
        raises for the given architectures propagates to the caller.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal
725

726
727
728
    def supports_multimodal_raw_input(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_multimodal_raw_input` flag of the resolved model.

        The model info is obtained from `inspect_model_cls`, so any error it
        raises for the given architectures propagates to the caller.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_multimodal_raw_input

734
735
    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_pp` flag of the resolved model.

        The model info is obtained from `inspect_model_cls`, so any error it
        raises for the given architectures propagates to the caller.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_pp
741

742
743
    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_inner_state` flag of the resolved model.

        The model info is obtained from `inspect_model_cls`, so any error it
        raises for the given architectures propagates to the caller.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_inner_state
749

750
751
    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_attention_free` flag of the resolved model.

        The model info is obtained from `inspect_model_cls`, so any error it
        raises for the given architectures propagates to the caller.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_attention_free
757

758
759
    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_hybrid` flag of the resolved model.

        The model info is obtained from `inspect_model_cls`, so any error it
        raises for the given architectures propagates to the caller.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.is_hybrid

766
767
    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_noops` flag of the resolved model.

        The model info is obtained from `inspect_model_cls`, so any error it
        raises for the given architectures propagates to the caller.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.has_noops

774
775
    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_transcription` flag of the resolved model.

        The model info is obtained from `inspect_model_cls`, so any error it
        raises for the given architectures propagates to the caller.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription

782
783
784
    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_transcription_only` flag of the resolved model.

        The model info is obtained from `inspect_model_cls`, so any error it
        raises for the given architectures propagates to the caller.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return model_info.supports_transcription_only

790
791
    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return True unless the resolved model is flagged `supports_v0_only`.

        The model info is obtained from `inspect_model_cls`, so any error it
        raises for the given architectures propagates to the caller.
        """
        model_info, _ = self.inspect_model_cls(architectures, model_config)
        return not model_info.supports_v0_only

798
799

# Module-level registry instance. Each architecture name in `_VLLM_MODELS`
# maps to a `_LazyRegisteredModel` that records only the module path (under
# `vllm.model_executor.models`) and the class name.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

# Return type of the callable handed to `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Call `fn` in a child process and return its unpickled result.

    `fn` and the path of an output file are serialized together with
    `cloudpickle` and passed to the child on stdin. After the child exits
    successfully, the result is read back from that output file.

    Args:
        fn: Zero-argument callable to execute in the child process.

    Returns:
        The object unpickled from the output file.

    Raises:
        RuntimeError: If the child exits with a non-zero status. The message
            contains the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently so that undecodable bytes in stderr cannot
            # raise and hide the actual subprocess failure
            stderr_text = returned.stderr.decode(errors="replace")
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr_text}") from e

        # The file is produced by the child process started above
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Run a pickled callable read from stdin and save its pickled result.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn()` after
    loading the general plugins, and writes the pickled return value to
    `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: unpickling executes arbitrary code, so stdin must only ever be
    # fed by a trusted parent process
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        f.write(pickle.dumps(result))
850
851
852


if __name__ == "__main__":
    # Script entry point: handle one request read from stdin.
    # NOTE(review): presumably this is what `_SUBPROCESS_COMMAND` launches;
    # confirm against its definition elsewhere in this file.
    _run()