# registry.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import importlib
8
import os
9
import pickle
10
11
import subprocess
import sys
12
import tempfile
13
from abc import ABC, abstractmethod
14
from collections.abc import Set
15
16
from dataclasses import dataclass, field
from functools import lru_cache
17
from typing import Callable, Optional, TypeVar, Union
18

19
import cloudpickle
20
21
22
23
import torch.nn as nn

from vllm.logger import init_logger

24
25
26
27
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
                         supports_multimodal, supports_pp,
                         supports_transcription, supports_v0_only)
28
from .interfaces_base import is_text_generation_model
29
30
31

logger = init_logger(__name__)

32
# yapf: disable
33
34
# Maps an architecture name to a `(module name, class name)` pair.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),  # noqa: E501
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

# Maps an architecture name to a `(module name, class name)` pair.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # [Auto-converted (see adapters.py)]
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

166
167
168
169
170
171
# Maps an architecture name to a `(module name, class name)` pair.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),  # noqa: E501
}

176
# Maps an architecture name to a `(module name, class name)` pair.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
223
224

# Maps an architecture name to a `(module name, class name)` pair.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EAGLEModel": ("eagle", "EAGLE"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
234

235
# Maps an architecture name to a `(module name, class name)` pair.
_TRANSFORMERS_MODELS = {
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
238
# yapf: enable
239

240
# Union of every architecture table above. When the same key appears in
# several tables, the entry from the table unpacked last wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}

249
250
251
252
253
254
255
256
# Argument list handed to subprocess.run(). It is kept as a module-level
# variable so it can be modified when needed, e.g. when things are packed
# together in par format and sys.executable might not be the target we
# want to run.
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]

257

258
259
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable record of the capability flags computed for a model class.

    Each boolean field holds the result of the identically named interface
    check applied to the model class in `from_model_cls`.
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """
        Build a `_ModelInfo` by running every interface check on `model`.

        The architecture is taken from the class name of `model`.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=True,  # Can convert any model into a pooling model
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            has_noops=has_noops(model),
            supports_transcription=supports_transcription(model),
            supports_v0_only=supports_v0_only(model),
        )
289
290


291
class _BaseRegisteredModel(ABC):
    """
    Abstract interface shared by every entry stored in the model registry.
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing this entry's model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class this entry refers to."""
        raise NotImplementedError
300
301


302
303
304
305
306
307
308
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability flags computed once from `model_cls` at construction time.
    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap `model_cls` together with its freshly computed `_ModelInfo`."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the stored `_ModelInfo`; no import or inspection happens."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Dotted module path that is imported when the class is first needed.
    module_name: str
    # Attribute looked up on that module to obtain the model class.
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """Compute the `_ModelInfo` for this model via `_run_in_subprocess`."""
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """
    Load the model class for `model_arch`, returning `None` on failure.

    The platform check runs before the `try`, so any error it raises
    propagates to the caller. Exceptions from the load itself are logged
    and turned into `None`. Results, including `None`, are memoized by
    `lru_cache` on the `(model_arch, model)` pair.
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
356
357


358
359
360
361
362
363
364
365
366
367
368
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """
    Inspect `model` and return its `_ModelInfo`, or `None` on failure.

    Any exception raised during inspection is logged together with
    `model_arch`. Results, including `None`, are memoized by `lru_cache`.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        info = None
    return info
369
370


371
372
373
@dataclass
class _ModelRegistry:
    """
    Maps architecture names to registered models and answers capability
    queries about them.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a view of every registered architecture name."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that already exists logs a warning and
        overwrites the previous entry.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not in the
                `<module>:<class>` format.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`), not
            # of `model_arch`, which has already been validated above.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """
        Always raise `ValueError` for `architectures`.

        The message distinguishes architectures that are registered but
        could not be inspected/loaded from ones that are not registered.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Load the class for `model_arch`; `None` if unknown or failing."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect `model_arch`; `None` if unknown or failing."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, list[str]],
    ) -> list[str]:
        """
        Return the registered subset of `architectures`, in order.

        A single string is treated as a one-element list. If any requested
        architecture is not registered, `"TransformersForCausalLM"` is
        appended as a fallback.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # keep only the supported architectures
        normalized_arch = [
            arch for arch in architectures if arch in self.models
        ]

        # make sure Transformers backend is put at the last as a fallback
        if len(normalized_arch) != len(architectures):
            normalized_arch.append("TransformersForCausalLM")
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[_ModelInfo, str]:
        """
        Return `(model_info, arch)` for the first architecture that can be
        inspected.

        Raises:
            ValueError: If no candidate architecture can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[type[nn.Module], str]:
        """
        Return `(model_cls, arch)` for the first architecture that can be
        loaded.

        Raises:
            ValueError: If no candidate architecture can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `is_text_generation_model` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `is_pooling_model` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `supports_cross_encoding` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `supports_multimodal` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `supports_pp` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `has_inner_state` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `is_attention_free` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `is_hybrid` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `has_noops` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `supports_transcription` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return True unless the resolved model is marked `supports_v0_only`."""
        model_info, _ = self.inspect_model_cls(architectures)
        return not model_info.supports_v0_only

567
568

# Global registry, pre-populated with a lazy entry for every architecture in
# `_VLLM_MODELS`, so no model module is imported at this point.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Run `fn` in a subprocess started with `_SUBPROCESS_COMMAND` and return
    its result.

    `fn` and an output path are serialized with `cloudpickle` and passed to
    the subprocess on stdin. The result is read back from the output file
    with `pickle`.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero return code;
            the message contains the subprocess's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # wrap raised exception to provide more information;
            # `errors="replace"` keeps undecodable stderr bytes from raising
            # UnicodeDecodeError and hiding the real failure
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr_text}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """
    Subprocess entry point: read a pickled `(fn, output_file)` pair from
    stdin, call `fn`, and pickle its result into `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE(review): unpickling executes arbitrary code. This is only safe if
    # stdin is always written by the trusted parent process — confirm that no
    # other caller feeds this entry point.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)
618
619
620


# Entry point used when this module is executed with `python -m`.
if __name__ == "__main__":
    _run()