registry.py 28.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import importlib
8
import os
9
import pickle
10
11
import subprocess
import sys
12
import tempfile
13
from abc import ABC, abstractmethod
14
from collections.abc import Set
15
from dataclasses import asdict, dataclass, field
16
from functools import lru_cache
17
from typing import Callable, Optional, TypeVar, Union
18
19
20
21
22

import torch.nn as nn

from vllm.logger import init_logger

23
24
25
26
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
                         supports_multimodal, supports_pp,
                         supports_transcription, supports_v0_only)
27
from .interfaces_base import is_text_generation_model

# Module-level logger used by the registry helpers below.
logger = init_logger(__name__)

# yapf: disable

# Each registry dict below maps an architecture name to a pair
# (module name relative to `vllm.model_executor.models`, class name in it).
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5_ForCausalLM": ("ernie45", "Ernie4_5_ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    # TODO(ywang96): Support multimodal gemma3n
    "Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"),    # noqa: E501
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

# Architectures served as embedding/pooling models.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

# Architectures served as cross-encoders (sequence-pair scoring).
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    # [Auto-converted (see adapters.py)]
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501
}

# Architectures that accept non-text inputs (images, video, audio).
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}

# Draft/auxiliary models used for speculative decoding.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    # Temporarily disabled.
    # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

# Generic Transformers-backend implementation, also used as the last-resort
# fallback by `_ModelRegistry._normalize_archs`.
_TRANSFORMERS_MODELS = {
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
# yapf: enable

# Union of all built-in registries. On duplicate keys the later dict wins;
# the keys currently shared between the dicts above map to identical values,
# so the merge order does not change any entry today.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}

# This variable is used as the args for subprocess.run(). We
# can modify this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
_SUBPROCESS_COMMAND = [
    sys.executable, "-m", "vllm.model_executor.models.registry"
]


@dataclass(frozen=True)
class _ModelInfo:
    """Capability flags of a model class, as reported by the interface checks.

    Frozen (and therefore hashable) because it is embedded in
    `_RegisteredModel`, which is passed as an argument to the `lru_cache`d
    helpers below. It is also pickled back from the inspection subprocess.
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_transcription_only: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Build a `_ModelInfo` by running every interface check on `model`."""
        transcription = supports_transcription(model)
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=True,  # Can convert any model into a pooling model
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=transcription,
            # Only read the attribute when the transcription interface is
            # supported; other models may not define it.
            supports_transcription_only=(transcription and
                                         model.supports_transcription_only),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )


class _BaseRegisteredModel(ABC):
    """Interface for an entry in `_ModelRegistry`."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return (importing if necessary) the model class itself."""
        raise NotImplementedError


@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Wrap an imported class, running the interface checks eagerly."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    module_name: str
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load the model class for `model_arch`, returning None on failure.

    The platform check runs outside the `try`, so any exception it raises
    propagates to the caller. Failures while loading the class itself are
    logged with a traceback and converted to `None`. Results are memoized
    per `(model_arch, model)` pair.
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None


@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Return `model.inspect_model_cls()`, or None if inspection fails.

    Failures are logged with a traceback. Results are memoized per
    `(model_arch, model)` pair, so each entry is inspected at most once.
    For lazy entries, inspection involves spawning a subprocess.
    """
    try:
        return model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        return None


@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to (possibly lazy) model entries."""

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of every registered architecture name."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that already exists logs a warning and
        replaces the previous entry.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If a string `model_cls` is not `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the offending argument's type (`model_cls`), not
            # `model_arch`, which was already validated as a string above.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        # Warn only once `model_cls` is known to be valid, i.e. when the
        # existing entry really is about to be replaced.
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Always raise ValueError describing why `architectures` failed."""
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Load the class for a registered `model_arch`, else return None."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect `model_arch`, returning None if unknown or if it fails.

        An unregistered `*ForSequenceClassification` architecture is
        inspected through its registered `*ForCausalLM` counterpart and
        reported with `supports_cross_encoding=True`.
        """
        if model_arch in self.models:
            return _try_inspect_model_cls(model_arch, self.models[model_arch])

        if model_arch.endswith("ForSequenceClassification"):
            causal_lm_arch = model_arch.replace("ForSequenceClassification",
                                                "ForCausalLM")
            if causal_lm_arch not in self.models:
                return None

            info = _try_inspect_model_cls(causal_lm_arch,
                                          self.models[causal_lm_arch])
            if info is None:
                # Inspection of the base model failed (already logged).
                # Propagate the failure instead of crashing in `asdict()`.
                return None

            return _ModelInfo(**{
                **asdict(info),
                "architecture": model_arch,
                "supports_cross_encoding": True,
            })

        return None

    def _normalize_archs(
        self,
        architectures: Union[str, list[str]],
    ) -> list[str]:
        """Order the candidate architectures to try.

        The result is built in three parts:

        1. The directly registered architectures, in input order.
        2. Unregistered `*ForSequenceClassification` architectures whose
           `*ForCausalLM` counterpart is registered.
        3. `"TransformersForCausalLM"` as a final fallback, added when some
           input architecture matched neither of the above.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # keep only the architectures that are registered directly
        normalized_arch = [
            arch for arch in architectures if arch in self.models
        ]

        # try automatic conversion in adapters.py
        for arch in architectures:
            # Already kept above; appending it again would duplicate the
            # entry and wrongly trigger the Transformers fallback below.
            if arch in self.models:
                continue
            if not arch.endswith("ForSequenceClassification"):
                continue
            causal_lm_arch = arch.replace("ForSequenceClassification",
                                          "ForCausalLM")
            if causal_lm_arch in self.models:
                normalized_arch.append(arch)

        # make sure Transformers backend is put at the last as a fallback
        if len(normalized_arch) != len(architectures):
            normalized_arch.append("TransformersForCausalLM")
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[_ModelInfo, str]:
        """Return `(info, arch)` for the first inspectable architecture.

        Raises:
            ValueError: If no candidate architecture can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[type[nn.Module], str]:
        """Return `(model_cls, arch)` for the first loadable architecture.

        Raises:
            ValueError: If no candidate architecture can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    # The predicates below each read one flag from `inspect_model_cls` and
    # therefore raise ValueError under the same conditions it does.

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription_only

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return not model_info.supports_v0_only


# Global registry, pre-populated with every built-in architecture. Each entry
# is imported lazily from `vllm.model_executor.models.<module>` on first use.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a fresh Python process and return its unpickled result.

    `fn` and an output path are serialized with `cloudpickle` and sent on the
    child's stdin. The child is started with `_SUBPROCESS_COMMAND`, which
    runs this module's `_run`. It writes the pickled return value of `fn()`
    to that path, and the value is read back here.

    Raises:
        RuntimeError: If the child exits with a non-zero status. The message
            includes the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the raised exception to provide more information. Decode
            # leniently so undecodable bytes in stderr cannot mask the error.
            stderr = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr}") from e

        # The file was written by the child launched above, inside a
        # temporary directory created by this call.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Child-process entry point used by `_run_in_subprocess`.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn()`, and
    writes the pickled result to `output_file`.
    """
    # Setup plugins (presumably so plugin-registered models are visible in
    # the child as well — TODO confirm)
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # The payload is produced by the parent process, not by external input.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)


if __name__ == "__main__":
    _run()