# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
import importlib
import os
import pickle
import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
from collections.abc import Set
from dataclasses import asdict, dataclass, field, replace
from functools import lru_cache
from typing import Callable, Optional, TypeVar, Union

import torch.nn as nn

from vllm.logger import init_logger

from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
                         supports_multimodal, supports_pp,
                         supports_transcription, supports_v0_only)
from .interfaces_base import is_text_generation_model

# Module-level logger, named after this module via `__name__`.
logger = init_logger(__name__)

# yapf: disable
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
34
35
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
Raghav Ravishankar's avatar
Raghav Ravishankar committed
36
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
37
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
38
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
39
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
40
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
41
42
43
44
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
45
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
Yu Chin Fabian Lim's avatar
Yu Chin Fabian Lim committed
46
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
47
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
48
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
49
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
50
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
51
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
52
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
53
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
54
55
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
56
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
57
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
58
59
    "Ernie4_5_ForCausalLM": ("ernie45", "Ernie4_5_ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
60
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
61
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
62
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
63
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
64
65
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
66
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
Robert Shaw's avatar
Robert Shaw committed
67
68
    #TODO(ywang96): Support multimodal gemma3n
    "Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"),    # noqa: E501
69
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
70
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
71
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
72
73
74
75
76
77
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
78
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
79
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
80
    "GritLM": ("gritlm", "GritLM"),
Michael Goin's avatar
Michael Goin committed
81
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
82
83
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
84
85
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
86
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
87
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
88
89
90
91
92
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
93
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
94
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
Dhia Eddine Rhaiem's avatar
Dhia Eddine Rhaiem committed
95
    "FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"),
96
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
97
98
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
99
100
101
102
103
104
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
105
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
106
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
Luis Vega's avatar
Luis Vega committed
107
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
108
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
109
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
110
111
112
113
114
115
116
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
117
    "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
Shinichi Hemmi's avatar
Shinichi Hemmi committed
118
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
119
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
120
121
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
122
123
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
124
125
126
127
128
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
129
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
130
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
131
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
132
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
133
134
135
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
136
137
138
}

# Maps an architecture name to `(module, class)`, where `module` is the module
# name relative to `vllm.model_executor.models` and `class` is the name of the
# model class to load from that module.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
xsank's avatar
xsank committed
184
185
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
186
    # [Auto-converted (see adapters.py)]
187
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501,
188
189
}

_MULTIMODAL_MODELS = {
191
    # [Decoder-only]
192
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
Jennifer Zhao's avatar
Jennifer Zhao committed
193
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
194
195
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
196
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
197
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
198
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
199
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
200
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
201
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
202
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
203
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
204
    "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
205
    "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"),  # noqa: E501
206
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
207
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
208
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
209
210
211
212
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
213
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
214
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
215
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
216
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
217
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
218
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
219
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
220
    "Ovis": ("ovis", "Ovis"),
221
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
222
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
223
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
224
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
225
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
Roger Wang's avatar
Roger Wang committed
226
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
227
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
228
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
229
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
230
    "UltravoxModel": ("ultravox", "UltravoxModel"),
231
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
汪志鹏's avatar
汪志鹏 committed
232
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
233
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
Patrick von Platen's avatar
Patrick von Platen committed
234
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
235
    # [Encoder-decoder]
236
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
237
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
238
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
239
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
240
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
241
}

_SPECULATIVE_DECODING_MODELS = {
244
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
245
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
zhiweiz's avatar
zhiweiz committed
246
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
247
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
248
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
249
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
250
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
251
    "MedusaModel": ("medusa", "Medusa"),
252
253
254
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
255
}
_TRANSFORMERS_MODELS = {
258
    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
259
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
260
}
# yapf: enable

# Union of every built-in architecture table. When the same architecture name
# appears in more than one table, the entry from the later table wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}

# This variable is used as the args for subprocess.run(). We
# can modify this variable to alter the args if needed, e.g.
# when we use the par format to pack things together, sys.executable
# might not be the target we want to run.
_SUBPROCESS_COMMAND = [
    sys.executable, "-m", "vllm.model_executor.models.registry"
]

# Architecture name -> version string.
# NOTE(review): presumably the last vLLM release that supported the
# architecture — confirm against the code that reads this mapping.
_PREVIOUSLY_SUPPORTED_MODELS = {"Phi3SmallForCausalLM": "0.9.2"}

@dataclass(frozen=True)
class _ModelInfo:
285
    architecture: str
286
    is_text_generation_model: bool
287
    is_pooling_model: bool
288
    supports_cross_encoding: bool
289
290
    supports_multimodal: bool
    supports_pp: bool
291
292
    has_inner_state: bool
    is_attention_free: bool
293
    is_hybrid: bool
294
    has_noops: bool
295
    supports_transcription: bool
296
    supports_transcription_only: bool
297
    supports_v0_only: bool
298
299

    @staticmethod
300
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
301
        return _ModelInfo(
302
            architecture=model.__name__,
303
            is_text_generation_model=is_text_generation_model(model),
304
            is_pooling_model=True,  # Can convert any model into a pooling model
305
            supports_cross_encoding=supports_cross_encoding(model),
306
307
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
308
309
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
310
            is_hybrid=is_hybrid(model),
311
            supports_transcription=supports_transcription(model),
312
313
            supports_transcription_only=(supports_transcription(model) and
                                         model.supports_transcription_only),
314
            supports_v0_only=supports_v0_only(model),
315
            has_noops=has_noops(model),
316
        )


class _BaseRegisteredModel(ABC):
    """Interface shared by every entry stored in the model registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class itself."""
        raise NotImplementedError


@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability flags computed when the entry is created.
    interfaces: _ModelInfo
    # The imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Create an entry for `model_cls`, inspecting it immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags stored on this entry."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class stored on this entry."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Importable module path that defines the model class.
    module_name: str
    # Attribute name of the model class inside that module.
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """Import and inspect the model class in a subprocess."""
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """
    Load the model class behind a registry entry.

    The architecture is first checked against the current platform; that
    check runs outside the `try`, so an exception it raises propagates to
    the caller. Failures while loading the class are logged instead.

    Args:
        model_arch: Architecture name, used for the platform check and in
            the log message.
        model: Registry entry to load the class from.

    Returns:
        The model class, or `None` if loading raised an exception.
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)

    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None


@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """
    Inspect the model class behind a registry entry.

    Args:
        model_arch: Architecture name, used only in the log message.
        model: Registry entry to inspect.

    Returns:
        The entry's `_ModelInfo`, or `None` if inspection raised an
        exception (the exception is logged).
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        info = None
    return info


@dataclass
class _ModelRegistry:
    """
    Registry that maps architecture names to model entries and answers
    capability queries about them.
    """
    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a view of every registered architecture name."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not in the
                `<module>:<class>` format.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`), not
            # of `model_arch`, which has already been validated above.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """
        Raise a `ValueError` explaining why none of `architectures` could be
        resolved. This method never returns normally.
        """
        all_supported_archs = self.get_supported_archs()

        # At least one architecture is registered, so the failure happened
        # while inspecting/loading it rather than while looking it up.
        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        # Give a more helpful message for architectures that were dropped.
        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]
                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    @staticmethod
    def _causal_lm_arch_for(model_arch: str) -> Optional[str]:
        """
        Return the `*ForCausalLM` name matching a
        `*ForSequenceClassification` architecture, or `None` if `model_arch`
        does not end with `ForSequenceClassification`.
        """
        suffix = "ForSequenceClassification"
        if not model_arch.endswith(suffix):
            return None
        return model_arch.removesuffix(suffix) + "ForCausalLM"

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Load the class for `model_arch`, or return `None` on failure."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """
        Inspect `model_arch`, or return `None` if it cannot be inspected.

        An unregistered `*ForSequenceClassification` architecture falls back
        to the registered `*ForCausalLM` architecture of the same family; the
        returned info then carries the requested architecture name and
        `supports_cross_encoding=True`.
        """
        if model_arch in self.models:
            return _try_inspect_model_cls(model_arch, self.models[model_arch])

        causal_lm_arch = self._causal_lm_arch_for(model_arch)
        if causal_lm_arch is None or causal_lm_arch not in self.models:
            return None

        info = _try_inspect_model_cls(causal_lm_arch,
                                      self.models[causal_lm_arch])
        # Inspection of the fallback architecture may itself have failed.
        if info is None:
            return None

        return replace(info,
                       architecture=model_arch,
                       supports_cross_encoding=True)

    def _normalize_archs(
        self,
        architectures: Union[str, list[str]],
    ) -> list[str]:
        """
        Reduce `architectures` to the ones this registry can resolve.

        Returns:
            Registered architectures in their input order, followed by
            auto-convertible `*ForSequenceClassification` architectures, with
            every `TransformersFor*` architecture moved to the end.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # filter out support architectures
        normalized_arch = [
            arch for arch in architectures if arch in self.models
        ]

        # try automatic conversion in adapters.py
        for arch in architectures:
            causal_lm_arch = self._causal_lm_arch_for(arch)
            if causal_lm_arch is None or causal_lm_arch not in self.models:
                continue
            # Skip architectures that are already in the result (e.g. ones
            # that are registered directly) to avoid duplicates.
            if arch not in normalized_arch:
                normalized_arch.append(arch)

        # NOTE(Isotr0py): Be careful of architectures' order!
        # Make sure Transformers backend architecture is at the end of the
        # list, otherwise pooling models automatic conversion will fail!
        # Build two new lists instead of removing/appending while iterating,
        # which would skip elements.
        transformers_archs = [
            arch for arch in normalized_arch
            if arch.startswith("TransformersFor")
        ]
        other_archs = [
            arch for arch in normalized_arch
            if not arch.startswith("TransformersFor")
        ]

        return other_archs + transformers_archs

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[_ModelInfo, str]:
        """
        Return `(model_info, arch)` for the first architecture that can be
        inspected.

        Raises:
            ValueError: If no architecture can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[type[nn.Module], str]:
        """
        Return `(model_cls, arch)` for the first architecture that can be
        loaded.

        Raises:
            ValueError: If no architecture can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    # Each query below inspects the first resolvable architecture and
    # returns the corresponding `_ModelInfo` flag.

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription_only

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return not model_info.supports_v0_only


# Registry instance pre-populated with one lazy entry per built-in
# architecture; a model module is only imported when its entry is loaded.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

# Return type of the callable passed to `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Run `fn` in a fresh Python process and return its result.

    `fn` and an output path are serialized with `cloudpickle` and piped to
    the command in `_SUBPROCESS_COMMAND` over stdin. The result is read back
    from the output path with `pickle`.

    Args:
        fn: Zero-argument callable to execute in the subprocess.

    Returns:
        Whatever was pickled to the output file by the subprocess.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status; the
            message contains the subprocess's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{returned.stderr.decode()}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """
    Subprocess entry point: read a pickled `(fn, output_file)` pair from
    stdin, call `fn()`, and pickle its result to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: unpickling can execute arbitrary code. The payload here is the
    # one written to stdin by `_run_in_subprocess`; never feed this entry
    # point data from an untrusted source.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)


# Executed when this module is launched with `python -m`, which is how
# `_SUBPROCESS_COMMAND` invokes it.
if __name__ == "__main__":
    _run()