# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
import importlib
8
import os
9
import pickle
10
11
import subprocess
import sys
12
import tempfile
13
from abc import ABC, abstractmethod
14
from collections.abc import Set
15
from dataclasses import asdict, dataclass, field
16
from functools import lru_cache
17
from typing import Callable, Optional, TypeVar, Union
18
19
20
21
22

import torch.nn as nn

from vllm.logger import init_logger

23
24
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
25
26
                         supports_multimodal, supports_multimodal_raw_input,
                         supports_pp, supports_transcription, supports_v0_only)
27
from .interfaces_base import is_text_generation_model
28
29
30

logger = init_logger(__name__)

31
# yapf: disable
32
33
# Text-generation architectures.
# Each entry maps an architecture name to a
# `(module name relative to vllm.model_executor.models, class name)` pair.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5_ForCausalLM": ("ernie45", "Ernie4_5_ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    # TODO(ywang96): Support multimodal gemma3n
    "Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"),    # noqa: E501
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),  # noqa: E501
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

# Embedding / pooling architectures.
# Each entry maps an architecture name to a
# `(module name relative to vllm.model_executor.models, class name)` pair.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

180
181
182
183
184
185
# Cross-encoder (sequence classification / ranking) architectures.
# Each entry maps an architecture name to a
# `(module name relative to vllm.model_executor.models, class name)` pair.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501
}

192
# Multimodal architectures.
# Each entry maps an architecture name to a
# `(module name relative to vllm.model_executor.models, class name)` pair.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"),  # noqa: E501
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
246
247

# Speculative-decoding (draft / MTP / Eagle / Medusa) architectures.
# Each entry maps an architecture name to a
# `(module name relative to vllm.model_executor.models, class name)` pair.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
260

261
262
263
264
265
# Architectures served through the `transformers` module.
# Each entry maps an architecture name to a
# `(module name relative to vllm.model_executor.models, class name)` pair.
_TRANSFORMERS_SUPPORTED_MODELS = {
    "Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
}

# The generic Transformers backend classes themselves.
_TRANSFORMERS_BACKEND_MODELS = {
    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
269
# yapf: enable
270

271
# Union of every built-in architecture table. Tables are merged in the order
# listed, so a key repeated in a later table overrides the earlier entry.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

281
282
283
284
285
286
287
288
# Argument vector handed to `subprocess.run()` when model inspection is
# delegated to a child process. It is kept as a module-level variable so that
# it can be overridden when `sys.executable` is not the right interpreter to
# launch (e.g. when everything is packed together in par format).
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]

# Architecture name -> version string.
# NOTE(review): presumably the last vLLM release that supported the
# architecture; confirm against the code that consumes this table.
_PREVIOUSLY_SUPPORTED_MODELS = {
    "Phi3SmallForCausalLM": "0.9.2",
}

291

292
293
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable summary of the capability flags of a model class.

    Instances are plain data (a string and booleans).
    """
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_transcription_only: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """
        Build a `_ModelInfo` by probing `model` with the interface helpers.

        Args:
            model: The model class to inspect.

        Returns:
            A `_ModelInfo` whose `architecture` is `model.__name__`.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=True,  # Can convert any model into a pooling model
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input=supports_multimodal_raw_input(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            # Only read the class attribute when the model supports
            # transcription at all.
            supports_transcription_only=(supports_transcription(model) and
                                         model.supports_transcription_only),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
328
329


330
class _BaseRegisteredModel(ABC):
    """Abstract interface shared by every entry stored in the registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability summary of the underlying model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the underlying model class."""
        raise NotImplementedError
339
340


341
342
343
344
345
346
347
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability summary computed when the entry is created.
    interfaces: _ModelInfo
    # The imported model class itself.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Create an entry for `model_cls`, inspecting it immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability summary stored on this entry."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the already-imported model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Module path passed to `importlib.import_module`.
    module_name: str
    # Attribute looked up on the imported module.
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """Import and inspect the model class inside a subprocess."""
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` in this process and return `class_name`."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """
    Load the model class for `model_arch`, returning `None` on failure.

    The platform check runs outside the `try` block, so anything it raises
    propagates to the caller. Failures while loading are logged and turned
    into `None`. Results (including `None`) are memoized by `lru_cache`.
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
395
396


397
398
399
400
401
402
403
404
405
406
407
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """
    Inspect `model`, returning its `_ModelInfo` or `None` on failure.

    Any exception raised during inspection is logged together with
    `model_arch`. Results (including `None`) are memoized by `lru_cache`.
    """
    try:
        model_info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        model_info = None
    return model_info
408
409


410
411
412
@dataclass
class _ModelRegistry:
    """
    Maps architecture names to registered model entries and answers
    capability queries about them.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a view of every registered architecture name."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not in the
                `<module>:<class>` format.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`), not
            # of `model_arch`, which has already been validated above.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """
        Always raise `ValueError` for `architectures`.

        The message distinguishes architectures that are registered but
        failed inspection from architectures that are not registered.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Load the class registered for `model_arch`, or return `None`."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """
        Inspect `model_arch`, or return `None` when that is not possible.

        An unregistered `*ForSequenceClassification` architecture is
        resolved through its `*ForCausalLM` counterpart when that one is
        registered; the result is reported under the requested name with
        `supports_cross_encoding=True`.
        """
        if model_arch in self.models:
            return _try_inspect_model_cls(model_arch, self.models[model_arch])

        if model_arch.endswith("ForSequenceClassification"):
            causal_lm_arch = model_arch.replace("ForSequenceClassification",
                                                "ForCausalLM")
            if causal_lm_arch not in self.models:
                return None

            info = _try_inspect_model_cls(causal_lm_arch,
                                          self.models[causal_lm_arch])
            # Inspection may fail and yield `None`; `asdict(None)` would
            # raise a confusing TypeError instead of reporting the failure.
            if info is None:
                return None

            info = _ModelInfo(**dict(
                asdict(info), **{
                    "architecture": model_arch,
                    "supports_cross_encoding": True
                }))
            return info

        return None

    def _normalize_archs(
        self,
        architectures: Union[str, list[str]],
    ) -> list[str]:
        """
        Reduce `architectures` to the candidates this registry can resolve.

        Registered architectures come first, followed by
        `*ForSequenceClassification` names whose `*ForCausalLM` counterpart
        is registered. `TransformersFor*` entries are moved to the end.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # filter out support architectures
        normalized_arch = [
            arch for arch in architectures if arch in self.models
        ]

        # try automatic conversion in adapters.py
        for arch in architectures:
            if not arch.endswith("ForSequenceClassification"):
                continue
            causal_lm_arch = arch.replace("ForSequenceClassification",
                                          "ForCausalLM")
            if causal_lm_arch in self.models:
                normalized_arch.append(arch)

        # NOTE(Isotr0py): Be careful of architectures' order!
        # Make sure Transformers backend architecture is at the end of the
        # list, otherwise pooling models automatic conversion will fail!
        # A stable partition is used instead of remove()/append() on the
        # list being iterated, which skips the element after each removal.
        native_archs = [
            arch for arch in normalized_arch
            if not arch.startswith("TransformersFor")
        ]
        transformers_archs = [
            arch for arch in normalized_arch
            if arch.startswith("TransformersFor")
        ]

        return native_archs + transformers_archs

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[_ModelInfo, str]:
        """
        Return `(model_info, arch)` for the first inspectable architecture.

        Raises:
            ValueError: If none of the architectures can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[type[nn.Module], str]:
        """
        Return `(model_cls, arch)` for the first loadable architecture.

        Raises:
            ValueError: If none of the architectures can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    # Each query below resolves the architectures via `inspect_model_cls`
    # and returns one flag of the resulting `_ModelInfo`.

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def supports_multimodal_raw_input(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal_raw_input

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription_only

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return not model_info.supports_v0_only

650
651

# Global registry, pre-populated with one lazy entry per built-in
# architecture.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

# Return type of the callable executed by `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Execute `fn` in a child process and return its unpickled result.

    `fn` and an output path are serialized with `cloudpickle` and sent to the
    child (launched with `_SUBPROCESS_COMMAND`) on stdin. The child writes
    the pickled result to the output path, which is read back here.

    Args:
        fn: A zero-argument callable; its result must be picklable.

    Returns:
        The value returned by `fn` in the child process.

    Raises:
        RuntimeError: If the child exits with a non-zero return code. The
            message contains the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # wrap raised exception to provide more information; decode
            # leniently so undecodable stderr bytes cannot mask the error
            raise RuntimeError(
                "Error raised in subprocess:\n"
                f"{returned.stderr.decode(errors='replace')}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """
    Entry point executed inside the child process.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn()` and
    writes the pickled result to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: unpickling executes arbitrary code. In this file the payload is
    # produced by `_run_in_subprocess`; never feed this entry point untrusted
    # input.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)
702
703
704


# Executed when this file is run as a script or with `python -m`.
if __name__ == "__main__":
    _run()