registry.py 29.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import importlib
8
import os
9
import pickle
10
11
import subprocess
import sys
12
import tempfile
13
from abc import ABC, abstractmethod
14
from collections.abc import Set
15
from dataclasses import asdict, dataclass, field
16
from functools import lru_cache
17
from typing import Callable, Optional, TypeVar, Union
18
19
20
21
22

import torch.nn as nn

from vllm.logger import init_logger

23
24
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
25
26
                         supports_multimodal, supports_multimodal_raw_input,
                         supports_pp, supports_transcription, supports_v0_only)
27
from .interfaces_base import is_text_generation_model
28
29
30

# Module logger, used below for registration warnings and for logging
# exceptions raised while loading/inspecting model classes.
logger = init_logger(__name__)

31
# yapf: disable

# Architecture name (as found in a HF config's `architectures`) ->
# (module under `vllm.model_executor.models`, class name in that module).
# Each key must appear only once.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5_ForCausalLM": ("ernie45", "Ernie4_5_ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FM9GForCausalLM": ("fm9g", "FM9GForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    # TODO(ywang96): Support multimodal gemma3n
    "Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"),    # noqa: E501
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

# Pooling/embedding architectures: arch name -> (module, class name).
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Several architecture names map onto the Llama implementation; copy
        # every such text-generation entry so each one can also be pooled.
        arch_name: entry
        for arch_name, entry in _TEXT_GENERATION_MODELS.items()
        if entry[1] == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # PrithviGeoSpatialMAE consumes and produces images. It is listed here
    # only because, for now, it reuses the embedding-model code path.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

182
183
184
185
186
187
# Cross-encoder (sequence-pair scoring) architectures:
# arch name -> (module, class name).
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),  # noqa: E501
    "XLMRobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),  # noqa: E501
    "ModernBertForSequenceClassification": ("modernbert", "ModernBertForSequenceClassification"),  # noqa: E501
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
}

194
# Multimodal architectures: arch name -> (module, class name).
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # Llama 4 and Skywork-R1V are decoder-only, so they belong in this section
    # rather than under the encoder-decoder marker below.
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
246
247

# Draft / MTP / Eagle / Medusa heads used for speculative decoding:
# arch name -> (module, class name).
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    # Temporarily disabled.
    # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
260

261
# Transformers-backend wrappers. `_ModelRegistry._normalize_archs` moves any
# architecture whose name starts with "TransformersFor" to the end of the
# candidate list.
_TRANSFORMERS_MODELS = {
    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
# yapf: enable
266

267
# Every built-in architecture, merged from the tables above. If a key appears
# in more than one table, the value from the later table is kept while the
# key stays at the position of its first occurrence.
_VLLM_MODELS = (_TEXT_GENERATION_MODELS
                | _EMBEDDING_MODELS
                | _CROSS_ENCODER_MODELS
                | _MULTIMODAL_MODELS
                | _SPECULATIVE_DECODING_MODELS
                | _TRANSFORMERS_MODELS)

276
277
278
279
280
281
282
283
# Argument list handed to `subprocess.run()` by `_run_in_subprocess`. It is a
# module-level variable so it can be overridden when needed: for example,
# when the code is packaged in the par format, `sys.executable` might not be
# the interpreter we actually want to run.
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]

284
285
# Architectures that are no longer registered, mapped to a vLLM version
# string (presumably the last release that supported them -- confirm
# against the call sites).
_PREVIOUSLY_SUPPORTED_MODELS = {
    "Phi3SmallForCausalLM": "0.9.2",
}

286

287
288
@dataclass(frozen=True)
class _ModelInfo:
    """Immutable summary of the vLLM interfaces a model class implements.

    Built by `from_model_cls`. For lazily registered models it is built in a
    subprocess and pickled back, so every field is a plain str or bool.
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_transcription_only: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Probe `model` with each interface helper and record the result."""
        transcription = supports_transcription(model)
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            # Any model can be converted into a pooling model.
            is_pooling_model=True,
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input=supports_multimodal_raw_input(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=transcription,
            # The class attribute is only consulted for transcription models.
            supports_transcription_only=(transcription
                                         and model.supports_transcription_only),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
323
324


325
class _BaseRegisteredModel(ABC):
    """Common interface for an entry stored in `_ModelRegistry.models`."""

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class this entry refers to."""
        raise NotImplementedError

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability summary of the model class."""
        raise NotImplementedError
334
335


336
337
338
339
340
341
342
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    The capability summary is computed once, from the class object itself,
    when the entry is created.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=info, model_cls=model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        # Already computed at construction time; nothing to import or probe.
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    Only the import location (`module_name` and `class_name`) is stored.
    """
    module_name: str
    class_name: str

    def inspect_model_cls(self) -> _ModelInfo:
        # Importing happens in another process to avoid initializing CUDA
        # in this one.
        def _inspect() -> _ModelInfo:
            return _ModelInfo.from_model_cls(self.load_model_cls())

        return _run_in_subprocess(_inspect)

    def load_model_cls(self) -> type[nn.Module]:
        module = importlib.import_module(self.module_name)
        model_cls = getattr(module, self.class_name)
        return model_cls


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load the class behind `model`, returning None if loading raises.

    The platform check runs before the guarded section, so an exception from
    `verify_model_arch` reaches the caller. Results, including None, are
    memoized per `(model_arch, model)` pair.
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)

    loaded: Optional[type[nn.Module]] = None
    try:
        loaded = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
    return loaded
390
391


392
393
394
395
396
397
398
399
400
401
402
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Return `model`'s capability summary, or None (after logging) on error.

    Results, including None, are memoized per `(model_arch, model)` pair.
    """
    info: Optional[_ModelInfo] = None
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
    return info
403
404


405
406
407
@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to registered model entries.

    Lookup methods accept one architecture name or a list of names (e.g. the
    `architectures` list of a HuggingFace config). They use the first
    candidate, after normalization, that can be inspected or loaded.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of every registered architecture name."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Raises:
            TypeError: If `model_arch` is not a string, or if `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not of the form
                `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        # Build and validate the new entry first, so the overwrite warning
        # below is only emitted when the registration actually takes effect.
        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument, i.e. `model_cls`.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Always raise a `ValueError` explaining why no arch was usable.

        If any of `architectures` is registered, inspecting/loading it failed
        (the error was logged). Otherwise none of them is supported.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _get_causal_lm_arch_for_seq_cls(self,
                                        model_arch: str) -> Optional[str]:
        """Return the registered `*ForCausalLM` arch that the
        `*ForSequenceClassification` arch `model_arch` can be auto-converted
        from (see adapters.py), or None if there is none."""
        if not model_arch.endswith("ForSequenceClassification"):
            return None
        causal_lm_arch = model_arch.replace("ForSequenceClassification",
                                            "ForCausalLM")
        return causal_lm_arch if causal_lm_arch in self.models else None

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        if model_arch in self.models:
            return _try_inspect_model_cls(model_arch, self.models[model_arch])

        causal_lm_arch = self._get_causal_lm_arch_for_seq_cls(model_arch)
        if causal_lm_arch is None:
            return None

        info = _try_inspect_model_cls(causal_lm_arch,
                                      self.models[causal_lm_arch])
        # Inspecting the base model failed (already logged). Without this
        # check, `asdict(None)` would raise a TypeError.
        if info is None:
            return None

        # Reuse the base model's capabilities under the converted name.
        return _ModelInfo(**{
            **asdict(info),
            "architecture": model_arch,
            "supports_cross_encoding": True,
        })

    def _normalize_archs(
        self,
        architectures: Union[str, list[str]],
    ) -> list[str]:
        """Return the candidate architectures to try, in priority order.

        Keeps registered archs, then appends `*ForSequenceClassification`
        archs that can be auto-converted from a registered `*ForCausalLM`.
        Each candidate appears once. Transformers-backend archs go last.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # filter out unsupported architectures
        normalized_arch = [arch for arch in architectures
                           if arch in self.models]

        # try automatic conversion in adapters.py
        for arch in architectures:
            if arch in normalized_arch:
                continue
            if self._get_causal_lm_arch_for_seq_cls(arch) is not None:
                normalized_arch.append(arch)

        # NOTE(Isotr0py): Be careful of architectures' order!
        # Make sure Transformers backend architecture is at the end of the
        # list, otherwise pooling models automatic conversion will fail!
        # Use a stable partition: removing and appending while iterating the
        # same list skips elements, which could leave some Transformers
        # archs ahead of others.
        return ([
            arch for arch in normalized_arch
            if not arch.startswith("TransformersFor")
        ] + [
            arch for arch in normalized_arch
            if arch.startswith("TransformersFor")
        ])

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[_ModelInfo, str]:
        """Return `(model_info, arch)` for the first inspectable candidate.

        Raises:
            ValueError: If no candidate can be inspected.
        """
        if isinstance(architectures, str):
            architectures = [architectures]

        for arch in self._normalize_archs(architectures):
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        # Report the architectures the caller asked for. The normalized list
        # has unsupported entries filtered out and may be empty.
        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[type[nn.Module], str]:
        """Return `(model_cls, arch)` for the first loadable candidate.

        Raises:
            ValueError: If no candidate can be loaded.
        """
        if isinstance(architectures, str):
            architectures = [architectures]

        for arch in self._normalize_archs(architectures):
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Report the architectures the caller asked for (see above).
        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def supports_multimodal_raw_input(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal_raw_input

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription_only

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return not model_info.supports_v0_only

645
646

# Process-wide registry. Every built-in architecture starts out as a lazy
# entry that records only its import location.
ModelRegistry = _ModelRegistry(models={
    arch: _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{relative_module}",
        class_name=class_name,
    )
    for arch, (relative_module, class_name) in _VLLM_MODELS.items()
})

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
659
660
661
662
663
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

664
        # `cloudpickle` allows pickling lambda functions directly
665
        import cloudpickle
666
        input_bytes = cloudpickle.dumps((fn, output_filepath))
667
668
669

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
670
671
672
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)
673
674
675
676
677
678
679
680
681

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except Exception as e:
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{returned.stderr.decode()}") from e

682
        with open(output_filepath, "rb") as f:
683
684
685
686
687
688
689
690
691
692
693
            return pickle.load(f)


def _run() -> None:
    """Child-process entry point used by `_run_in_subprocess`.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn`, and
    pickles its return value into `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)


if __name__ == "__main__":
    _run()