# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import importlib
8
import os
9
import pickle
10
11
import subprocess
import sys
12
import tempfile
13
from abc import ABC, abstractmethod
14
from collections.abc import Set
15
16
from dataclasses import dataclass, field
from functools import lru_cache
17
from typing import Callable, Optional, TypeVar, Union
18
19
20
21
22

import torch.nn as nn

from vllm.logger import init_logger

23
24
25
26
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
                         supports_multimodal, supports_pp,
                         supports_transcription, supports_v0_only)
27
from .interfaces_base import is_text_generation_model
28
29
30

logger = init_logger(__name__)

31
# yapf: disable
32
33
# Maps an architecture name to `(module, class)`: `module` is resolved
# relative to `vllm.model_executor.models` when the global `ModelRegistry`
# is built, and `class` is the attribute looked up inside that module.
# Several aliases may point at the same implementation.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5_ForCausalLM": ("ernie45", "Ernie4_5_ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    # TODO(ywang96): Support multimodal gemma3n
    "Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"),    # noqa: E501
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

# Pooling/embedding architectures, in the same `(module, class)` format as
# `_TEXT_GENERATION_MODELS`.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

174
175
176
177
178
179
# Cross-encoder (sequence-classification / ranking) architectures, in the
# same `(module, class)` format as `_TEXT_GENERATION_MODELS`.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    # [Auto-converted (see adapters.py)]
    "GemmaForSequenceClassification": ("gemma", "GemmaForSequenceClassification"), # noqa: E501
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"), # noqa: E501
    "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501
}

189
# Multimodal architectures, in the same `(module, class)` format as
# `_TEXT_GENERATION_MODELS`.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
239
240

# Draft / speculative-decoding architectures, in the same `(module, class)`
# format as `_TEXT_GENERATION_MODELS`.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EAGLEModel": ("eagle", "EAGLE"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
250

251
# Generic Transformers-backend implementation; also used as the last-resort
# fallback when a requested architecture is not registered.
_TRANSFORMERS_MODELS = {
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
254
# yapf: enable
255

256
# Union of every built-in table. With dict unpacking, a key that appears in
# more than one table takes the value from the table listed last here.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}

265
266
267
268
269
270
271
272
# Argument vector passed to `subprocess.run()` when a model class must be
# inspected out-of-process. It is kept as a module-level variable so that it
# can be overridden, e.g. when things are packed together in the par format
# and `sys.executable` is not the interpreter that should be launched.
_SUBPROCESS_COMMAND = [sys.executable, "-m",
                       "vllm.model_executor.models.registry"]

273

274
275
@dataclass(frozen=True)
class _ModelInfo:
    """Immutable set of capability flags describing one model class.

    Instances are built by :meth:`from_model_cls`, which queries the
    interface predicates imported from ``.interfaces`` and
    ``.interfaces_base``.
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_transcription_only: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Collect the capability flags of ``model`` into a ``_ModelInfo``."""
        # Queried once and reused for both transcription-related flags.
        transcription = supports_transcription(model)
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=True,  # Can convert any model into a pooling model
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=transcription,
            # The class attribute is only read when the model supports
            # transcription at all (short-circuit `and`).
            supports_transcription_only=(transcription and
                                         model.supports_transcription_only),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
308
309


310
class _BaseRegisteredModel(ABC):
    """Interface implemented by every entry stored in ``_ModelRegistry``."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class, importing it if necessary."""
        raise NotImplementedError
319
320


321
322
323
324
325
326
327
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability flags, computed once when the entry is created.
    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap an already-imported class, inspecting it eagerly."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the flags computed at construction time."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored class; no import is needed."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Dotted module path handed to `importlib.import_module`.
    module_name: str
    # Attribute name of the model class inside `module_name`.
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """Import and inspect the class in a subprocess, returning its flags."""
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import ``module_name`` in this process and return ``class_name``."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load the class behind ``model``; return ``None`` if loading fails.

    ``current_platform.verify_model_arch`` runs outside the ``try`` block, so
    an exception it raises propagates to the caller instead of becoming
    ``None``. Any exception from ``load_model_cls`` is logged with its
    traceback. Successful and ``None`` results are memoized per
    ``(model_arch, model)`` pair by ``lru_cache``.
    """
    # NOTE(review): imported locally, presumably to avoid an import cycle or
    # early platform detection at module import time — confirm.
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
375
376


377
378
379
380
381
382
383
384
385
386
387
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Return the capability flags of ``model``, or ``None`` on failure.

    Any exception raised while inspecting is logged (with traceback) under
    ``model_arch`` and swallowed. Results, including ``None``, are memoized
    per ``(model_arch, model)`` pair by ``lru_cache``.
    """
    info: Optional[_ModelInfo] = None
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
    return info
388
389


390
391
392
@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to registered model entries.

    Entries are either already imported (``_RegisteredModel``) or imported
    on demand (``_LazyRegisteredModel``). Lookup methods accept a single
    architecture name or a list of candidates. Candidates are tried in
    order. ``TransformersForCausalLM`` is appended as a last-resort fallback
    whenever some requested architecture is not registered.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of all registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Raises:
            TypeError: If `model_arch` is not a string, or if `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not of the form
                `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        # Validate and wrap `model_cls` first, so that a rejected registration
        # neither logs an "overwritten" warning nor touches the registry.
        model: _BaseRegisteredModel
        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending `model_cls` argument.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a ``ValueError`` explaining why ``architectures`` failed.

        If at least one of them is registered, it was found but could not be
        inspected/loaded (the cause is in the logs). Otherwise none of them
        is supported and the registered names are listed.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Load ``model_arch``'s class, or ``None`` if unregistered/failed."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect ``model_arch``, or ``None`` if unregistered/failed."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, list[str]],
    ) -> list[str]:
        """Return the registered candidates from ``architectures`` in order.

        A single string is treated as a one-element list. If any requested
        architecture is not registered, ``TransformersForCausalLM`` is
        appended (at most once) as the final fallback.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # Keep only the supported (registered) architectures, preserving order.
        normalized_arch = [
            arch for arch in architectures if arch in self.models
        ]

        # Make sure the Transformers backend is tried last as a fallback, and
        # avoid listing it twice if the caller already requested it.
        fallback = "TransformersForCausalLM"
        if (len(normalized_arch) != len(architectures)
                and fallback not in normalized_arch):
            normalized_arch.append(fallback)
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[_ModelInfo, str]:
        """Return ``(info, arch)`` for the first inspectable candidate.

        Raises:
            ValueError: If no candidate can be inspected.
        """
        if isinstance(architectures, str):
            architectures = [architectures]

        for arch in self._normalize_archs(architectures):
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        # Report what the caller asked for, not the normalized list (which
        # may contain nothing but the Transformers fallback).
        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[type[nn.Module], str]:
        """Return ``(model_cls, arch)`` for the first loadable candidate.

        Raises:
            ValueError: If no candidate can be loaded.
        """
        if isinstance(architectures, str):
            architectures = [architectures]

        for arch in self._normalize_archs(architectures):
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Report what the caller asked for, not the normalized list.
        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the first inspectable candidate generates text."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the first inspectable candidate can be used for pooling."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the first inspectable candidate supports cross-encoding."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the first inspectable candidate is multimodal."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the first inspectable candidate reports `supports_pp`."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the first inspectable candidate has inner state."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the first inspectable candidate is attention-free."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the first inspectable candidate is hybrid."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the first inspectable candidate reports `has_noops`."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the first inspectable candidate supports transcription."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the first inspectable candidate is transcription-only."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_transcription_only

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the first inspectable candidate is not marked V0-only."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return not model_cls.supports_v0_only

593
594

# Global registry pre-populated with every built-in architecture. All entries
# are lazy: a model's module is only imported when its class is needed.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run ``fn`` in a fresh Python process and return its result.

    ``fn`` and the path of an output file are serialized with
    ``cloudpickle``, so lambdas and closures work. They are fed on stdin to
    ``_SUBPROCESS_COMMAND``, which by default runs this module as a script
    (see ``_run``). The child pickles ``fn()``'s result into the output file,
    which is read back here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status. The
            child's stderr is included in the message.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the error to surface the child's stderr. Decode leniently
            # so undecodable bytes cannot replace the real failure with a
            # UnicodeDecodeError.
            stderr = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr}") from e

        # The file lives in our private temp directory and was written by the
        # child we just launched.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Subprocess entry point used by ``_run_in_subprocess``.

    Reads a pickled ``(fn, output_file)`` pair from stdin, calls ``fn()``,
    and pickles the result into ``output_file``.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: unpickling can execute arbitrary code; stdin is expected to come
    # from the parent `_run_in_subprocess` call and must never be untrusted.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)
645
646
647


# Executed when this module is launched via `_SUBPROCESS_COMMAND`
# (`python -m vllm.model_executor.models.registry`).
if __name__ == "__main__":
    _run()