registry.py 34.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import importlib
8
import os
9
import pickle
10
11
import subprocess
import sys
12
import tempfile
13
from abc import ABC, abstractmethod
14
from collections.abc import Set
15
from dataclasses import dataclass, field
16
from functools import lru_cache
17
from typing import Callable, Optional, TypeVar, Union
18
19

import torch.nn as nn
20
import transformers
21

22
23
from vllm.config import (ModelConfig, ModelImpl, iter_architecture_defaults,
                         try_match_architecture_defaults)
24
from vllm.logger import init_logger
25
26
from vllm.transformers_utils.dynamic_module import (
    try_get_class_from_dynamic_module)
27

28
29
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
30
31
                         supports_multimodal, supports_multimodal_raw_input,
                         supports_pp, supports_transcription, supports_v0_only)
32
from .interfaces_base import is_pooling_model, is_text_generation_model
33
34
35

logger = init_logger(__name__)

36
# yapf: disable
37
38
# Text-generation architectures known to this registry.
#
# Keys are architecture names. Values are 2-tuples of strings; presumably
# `(module name, class name)` of the implementing model -- the code that
# consumes this table is outside this view, so confirm before relying on it.
# Several keys intentionally alias the same implementation (e.g. many
# Llama-compatible architectures point at `("llama", "LlamaForCausalLM")`).
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5_ForCausalLM": ("ernie45", "Ernie4_5_ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    # TODO(ywang96): Support multimodal gemma3n
    "Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"),    # noqa: E501
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),  # noqa: E501
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

# Embedding / pooling architectures, in the same `key -> 2-tuple` layout as
# `_TEXT_GENERATION_MODELS`. Every text-generation entry whose second tuple
# element is "LlamaForCausalLM" is re-exported here via the `**{...}`
# comprehension below; entries written after it override it on key collision.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

185
186
187
188
189
190
# Cross-encoder (sequence-classification / ranking) architectures, in the
# same `key -> 2-tuple` layout as the other registry tables in this file.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501
}

197
# Multimodal architectures, in the same `key -> 2-tuple` layout as the other
# registry tables in this file.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
251
252

# Draft / speculator architectures used for speculative decoding, in the same
# `key -> 2-tuple` layout as the other registry tables in this file.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
265

266
267
268
269
270
# Architectures that are routed to a generic Transformers-backed class rather
# than a dedicated implementation.
_TRANSFORMERS_SUPPORTED_MODELS = {
    "Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}

# The generic Transformers-backed classes themselves. An architecture that is
# already one of these keys is returned unchanged by
# `_ModelRegistry._try_resolve_transformers`.
_TRANSFORMERS_BACKEND_MODELS = {
    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
274
# yapf: enable
275

276
# Union of every per-category table above. Tables are merged in the order
# listed, so on a duplicate key the entry from the later table wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

286
287
288
289
290
291
292
293
# This variable is used as the args for subprocess.run(). We
# can modify this variable to alter the args if needed, e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
_SUBPROCESS_COMMAND = [
    sys.executable, "-m", "vllm.model_executor.models.registry"
]

# Architectures that used to be supported, mapped to the version string that
# `_ModelRegistry._raise_for_unsupported` reports as the last one supporting
# them.
_PREVIOUSLY_SUPPORTED_MODELS = {"Phi3SmallForCausalLM": "0.9.2"}

296

297
298
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable summary of a model class: its name plus capability flags.

    Each flag is the result of the same-named predicate imported from
    `.interfaces` / `.interfaces_base`, evaluated on the model class by
    `from_model_cls`.
    """
    # `__name__` of the inspected model class.
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    # True only when the model supports transcription AND its
    # `supports_transcription_only` class attribute is truthy.
    supports_transcription_only: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Build a `_ModelInfo` by evaluating every predicate on `model`."""
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input=supports_multimodal_raw_input(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            # The attribute is only read when transcription is supported.
            supports_transcription_only=(supports_transcription(model) and
                                         model.supports_transcription_only),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
333
334


335
class _BaseRegisteredModel(ABC):
    """
    Interface shared by every registry entry.

    An entry must be able to report a model's `_ModelInfo` and to return the
    model class itself.
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing the model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class."""
        raise NotImplementedError
344
345


346
347
348
349
350
351
352
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability summary computed once, when the entry is created.
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Create an entry for `model_cls`, inspecting it immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` computed at construction time."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Name passed to `importlib.import_module` when the class is needed.
    module_name: str
    # Attribute looked up on the imported module.
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """Compute the model's `_ModelInfo` via `_run_in_subprocess`."""
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """
    Load the class behind a registry entry, returning `None` on failure.

    The architecture is first checked with
    `current_platform.verify_model_arch`; that call sits outside the `try`,
    so anything it raises propagates to the caller. Any exception raised
    while loading the class is logged and turned into `None`.

    Results are memoized by `lru_cache` on `(model_arch, model)`, which
    includes `None` results.
    """
    # Imported here rather than at module level.
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
400
401


402
403
404
405
406
407
408
409
410
411
412
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """
    Inspect a registry entry, returning `None` if inspection fails.

    Any exception raised by `model.inspect_model_cls()` is logged together
    with `model_arch` and swallowed. Results are memoized by `lru_cache` on
    `(model_arch, model)`, which includes `None` results.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        info = None
    return info
413
414


415
416
417
@dataclass
class _ModelRegistry:
    # Keyed by model_arch
418
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
419

420
    def get_supported_archs(self) -> Set[str]:
        """
        Return the registered architecture names.

        This is the dict's own keys view, not a copy, so it reflects later
        registrations.
        """
        return self.models.keys()
422

423
424
425
    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that already exists logs a warning and
        overwrites the previous entry.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that does not split into
                exactly two parts on `:`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`), not
            # of `model_arch`, which has already been validated as a string.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        self.models[model_arch] = model
464

465
    def _raise_for_unsupported(self, architectures: list[str]):
        """
        Raise the most specific `ValueError` for unresolved architectures.

        This method never returns normally. The message depends on the first
        condition that matches:

        1. At least one architecture is registered: it was known but could
           not be inspected/loaded, so the user is pointed at the logs.
        2. An architecture is listed in `_PREVIOUSLY_SUPPORTED_MODELS`: the
           last supporting version is reported.
        3. Otherwise: the architectures are reported as unsupported, along
           with the supported ones.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")
486

487
    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """
        Load the class registered for `model_arch`.

        Returns `None` when the architecture is not registered; otherwise
        delegates to the module-level `_try_load_model_cls`.
        """
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])
493

494
    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """
        Inspect the entry registered for `model_arch`.

        Returns `None` when the architecture is not registered; otherwise
        delegates to the module-level `_try_inspect_model_cls`.
        """
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> Optional[str]:
        """
        Resolve `architecture` to a Transformers-backend architecture name.

        Returns:
            `architecture` itself if it is already a key of
            `_TRANSFORMERS_BACKEND_MODELS`; otherwise the value of
            `model_config._get_transformers_backend_cls()` when a compatible
            model class is found; otherwise `None` when
            `model_config.model_impl` is not `ModelImpl.TRANSFORMERS`.

        Raises:
            ValueError: If `model_config.model_impl` is
                `ModelImpl.TRANSFORMERS` and either no model class can be
                found or the one found is not backend compatible.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
                                           None) or {}

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        # The return values are deliberately discarded here; this pass is
        # only about the order in which the dynamic modules get loaded.
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        # Prefer a class shipped with the installed Transformers library.
        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            # Otherwise take the first `AutoModel*` entry that loads.
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                # Loop finished without `break`: nothing could be loaded.
                if model_config.model_impl != ModelImpl.TRANSFORMERS:
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom).")

        if not model_module.is_backend_compatible():
            if model_config.model_impl != ModelImpl.TRANSFORMERS:
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM.")

        return model_config._get_transformers_backend_cls()
562

563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """
        Map `architecture` onto a registered name when one can be derived.

        A registered name is returned unchanged. Otherwise, if
        `try_match_architecture_defaults` yields a `(suffix, _)` match, each
        suffix from `iter_architecture_defaults()` is substituted for it and
        the first substitution that is registered is returned. If nothing
        matches, `architecture` is returned unchanged.
        """
        registered = self.models
        if architecture in registered:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        runner = getattr(model_config, "runner_type", None)
        converter = getattr(model_config, "convert_type", None)
        match = try_match_architecture_defaults(
            architecture,
            runner_type=runner,
            convert_type=converter,
        )
        if not match:
            return architecture

        suffix, _ = match

        # Get the name of the base model to convert
        candidates = (architecture.replace(suffix, repl_suffix)
                      for repl_suffix, _ in iter_architecture_defaults())
        return next((name for name in candidates if name in registered),
                    architecture)
588

589
590
    def _normalize_archs(
        self,
        architectures: list[str],
        model_config: ModelConfig,
    ) -> list[str]:
        """
        Apply `_normalize_arch` to every entry of `architectures`.

        An empty list only triggers a warning; the (empty) result is still
        returned. The output has the same length and order as the input.
        """
        if not architectures:
            logger.warning("No model architectures are specified")

        return [
            self._normalize_arch(arch, model_config) for arch in architectures
        ]
600

601
602
    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """
        Find the first architecture that can be inspected.

        Resolution order:

        1. If `model_config.model_impl` is `ModelImpl.TRANSFORMERS`, try the
           Transformers backend for the first architecture.
        2. Try each architecture in order, looking it up under its
           normalized name; the original name is what gets returned.
        3. If `model_impl` is `AUTO` or `TRANSFORMERS`, fall back to the
           Transformers backend for the first architecture.

        Returns:
            `(model_info, architecture)` for the first successful step.

        Raises:
            ValueError: Via `_raise_for_unsupported` when nothing resolves,
                including when `architectures` is empty.
        """
        if isinstance(architectures, str):
            architectures = [architectures]

        normalized_archs = self._normalize_archs(architectures, model_config)

        # Require transformers impl
        # (`architectures` may be empty; skip instead of raising IndexError
        # so that the descriptive error below is reported.)
        if architectures and model_config.model_impl == ModelImpl.TRANSFORMERS:
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch, normalized_arch in zip(architectures, normalized_archs):
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl
        if architectures and model_config.model_impl in (
                ModelImpl.AUTO, ModelImpl.TRANSFORMERS):
            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)
635

636
637
    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return the model class of the first architecture that can be
        loaded.

        Resolution order:

        1. If `model_config.model_impl` is `ModelImpl.TRANSFORMERS`, the
           architecture returned by `_try_resolve_transformers` for the
           first given architecture is tried first.
        2. Each given architecture is tried in order, looked up under its
           normalized name.
        3. If `model_config.model_impl` is `ModelImpl.AUTO` or
           `ModelImpl.TRANSFORMERS`, the Transformers resolution is tried
           as a fallback.

        Args:
            architectures: A single architecture name or a list of names.
            model_config: Supplies `model_impl` and is forwarded to the
                normalization and Transformers resolution helpers.

        Returns:
            A `(model_cls, arch)` tuple. For step 2, `arch` is the name
            as given by the caller, not the normalized name.

        If nothing can be loaded, the call is delegated to
        `_raise_for_unsupported` (presumably raising; see that method).
        """
        if isinstance(architectures, str):
            architectures = [architectures]

        normalized_archs = self._normalize_archs(architectures, model_config)

        def load_via_transformers() -> Optional[tuple[type[nn.Module], str]]:
            # Only the first architecture is considered for the Transformers
            # implementation. An empty list has no first entry, so skip this
            # path instead of failing with an IndexError; the caller then
            # reaches `_raise_for_unsupported`.
            if not architectures:
                return None

            arch = self._try_resolve_transformers(architectures[0],
                                                  model_config)
            if arch is None:
                return None

            model_cls = self._try_load_model_cls(arch)
            if model_cls is None:
                return None

            return (model_cls, arch)

        # Require transformers impl
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            resolved = load_via_transformers()
            if resolved is not None:
                return resolved

        for arch, normalized_arch in zip(architectures, normalized_archs):
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl
        if model_config.model_impl in (ModelImpl.AUTO, ModelImpl.TRANSFORMERS):
            resolved = load_via_transformers()
            if resolved is not None:
                return resolved

        return self._raise_for_unsupported(architectures)
670

671
672
    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_text_generation_model` flag of the model info
        that `inspect_model_cls` yields for `architectures`."""
        info = self.inspect_model_cls(architectures, model_config)[0]
        return info.is_text_generation_model
678

679
    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_pooling_model` flag of the model info that
        `inspect_model_cls` yields for `architectures`."""
        info = self.inspect_model_cls(architectures, model_config)[0]
        return info.is_pooling_model
686

687
688
    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_cross_encoding` flag of the model info
        that `inspect_model_cls` yields for `architectures`."""
        info = self.inspect_model_cls(architectures, model_config)[0]
        return info.supports_cross_encoding
694

695
696
    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_multimodal` flag of the model info that
        `inspect_model_cls` yields for `architectures`."""
        info = self.inspect_model_cls(architectures, model_config)[0]
        return info.supports_multimodal
702

703
704
705
    def supports_multimodal_raw_input(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_multimodal_raw_input` flag of the model
        info that `inspect_model_cls` yields for `architectures`."""
        info = self.inspect_model_cls(architectures, model_config)[0]
        return info.supports_multimodal_raw_input

711
712
    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_pp` flag of the model info that
        `inspect_model_cls` yields for `architectures`."""
        info = self.inspect_model_cls(architectures, model_config)[0]
        return info.supports_pp
718

719
720
    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_inner_state` flag of the model info that
        `inspect_model_cls` yields for `architectures`."""
        info = self.inspect_model_cls(architectures, model_config)[0]
        return info.has_inner_state
726

727
728
    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_attention_free` flag of the model info that
        `inspect_model_cls` yields for `architectures`."""
        info = self.inspect_model_cls(architectures, model_config)[0]
        return info.is_attention_free
734

735
736
    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `is_hybrid` flag of the model info that
        `inspect_model_cls` yields for `architectures`."""
        info = self.inspect_model_cls(architectures, model_config)[0]
        return info.is_hybrid

743
744
    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `has_noops` flag of the model info that
        `inspect_model_cls` yields for `architectures`."""
        info = self.inspect_model_cls(architectures, model_config)[0]
        return info.has_noops

751
752
    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_transcription` flag of the model info
        that `inspect_model_cls` yields for `architectures`."""
        info = self.inspect_model_cls(architectures, model_config)[0]
        return info.supports_transcription

759
760
761
    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the `supports_transcription_only` flag of the model
        info that `inspect_model_cls` yields for `architectures`."""
        info = self.inspect_model_cls(architectures, model_config)[0]
        return info.supports_transcription_only

767
768
    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
        model_config: ModelConfig,
    ) -> bool:
        """Return the negation of the `supports_v0_only` flag of the
        model info that `inspect_model_cls` yields for `architectures`."""
        info = self.inspect_model_cls(architectures, model_config)[0]
        return not info.supports_v0_only

775
776

# Module-level registry: one `_LazyRegisteredModel` per entry of
# `_VLLM_MODELS`, keyed by architecture name. Each entry records only the
# module path and class name of the implementation.
ModelRegistry = _ModelRegistry({
    arch_name:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{module_stem}",
        class_name=class_name,
    )
    for arch_name, (module_stem, class_name) in _VLLM_MODELS.items()
})

# Return type of the callable handed to `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a separate process and return its result.

    `fn` and an output file path are serialized with `cloudpickle` and sent
    to the standard input of the process started with `_SUBPROCESS_COMMAND`.
    After the process exits successfully, the result is unpickled from the
    output file, which the subprocess is expected to have written.

    Args:
        fn: A zero-argument callable. It must be serializable by
            `cloudpickle`, and its return value must be picklable.

    Returns:
        The object unpickled from the output file.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero return code.
            The message contains the captured standard error output.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently: stderr may contain bytes that are not valid
            # UTF-8, and a decoding failure here would hide the real error.
            stderr_text = returned.stderr.decode(errors="replace")
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr_text}") from e

        # The file is written by the subprocess started above, so it is
        # unpickled as trusted data.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Execute one pickled task read from standard input.

    General plugins are loaded first. Standard input must contain a pickled
    `(fn, output_file)` pair; `fn` is called without arguments and its
    pickled return value is written to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: Unpickling executes arbitrary code, so standard input must only
    # ever carry data from a trusted sender.
    raw_task = sys.stdin.buffer.read()
    task_fn, output_path = pickle.loads(raw_task)

    outcome = task_fn()

    with open(output_path, "wb") as out:
        serialized = pickle.dumps(outcome)
        out.write(serialized)


# Runs the task when this module is executed as a script.
if __name__ == "__main__":
    _run()