# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""

import importlib
import os
import pickle
import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
from collections.abc import Set
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Callable, NoReturn, Optional, TypeVar, Union

import cloudpickle
import torch.nn as nn

from vllm.logger import init_logger

from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
                         supports_multimodal, supports_pp,
                         supports_transcription, supports_v0_only)
from .interfaces_base import is_text_generation_model

# Module-level logger: used for registration warnings and for reporting
# architectures that fail to load or be inspected.
logger = init_logger(__name__)

# yapf: disable
# Text-generation architectures. Each entry maps an architecture name to
# (module name under `vllm.model_executor.models`, class name in that module).
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

# Embedding / pooling architectures. Same value format as the other tables:
# (module name under `vllm.model_executor.models`, class name in that module).
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all:
        # every text-generation alias that resolves to `LlamaForCausalLM`.
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # [Auto-converted (see adapters.py)]
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}
# Cross-encoder (sequence-classification) architectures:
# (module name under `vllm.model_executor.models`, class name in that module).
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
}
# Multimodal architectures:
# (module name under `vllm.model_executor.models`, class name in that module).
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}

# Draft / auxiliary models used for speculative decoding:
# (module name under `vllm.model_executor.models`, class name in that module).
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EAGLEModel": ("eagle", "EAGLE"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
# Generic Transformers backend. `_ModelRegistry._normalize_archs` appends
# this architecture as a fallback when a requested architecture is unknown.
_TRANSFORMERS_MODELS = {
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
# yapf: enable
# Every built-in architecture known to vLLM. When the same key appears in
# more than one of the tables below, the entry from the later table wins
# (standard `**` dict-unpacking semantics).
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}
# argv passed to `subprocess.run()` to start the helper process that
# re-enters this module via `python -m`. It is a module-level list so it can
# be adjusted when needed, e.g. when everything is packed together in par
# format and `sys.executable` is not the interpreter we actually want to run.
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]
@dataclass(frozen=True)
class _ModelInfo:
    """Capability flags describing a model class.

    Normally built via `from_model_cls`, which probes the class with the
    interface predicates imported from `.interfaces` and `.interfaces_base`.
    The dataclass is frozen, so instances are immutable and hashable.
    """

    # `__name__` of the inspected model class.
    architecture: str
    is_text_generation_model: bool
    # Set to True by `from_model_cls`: any model can be converted into a
    # pooling model.
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Probe `model` and collect its capability flags.

        Args:
            model: The model class (not an instance) to inspect.

        Returns:
            A `_ModelInfo` whose `architecture` is `model.__name__`.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=True,  # Can convert any model into a pooling model
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
class _BaseRegisteredModel(ABC):
    """Interface for an entry stored in `_ModelRegistry.models`."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    Its `_ModelInfo` is computed once, up front, in `from_model_cls`.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap an already-imported model class, inspecting it eagerly."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the precomputed capability flags."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the wrapped model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    The model class is located by `module_name` / `class_name` and is only
    imported when `load_model_cls` is called.
    """
    module_name: str
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """Import and inspect the model class inside a subprocess."""
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute.

        Raises:
            ImportError: If the module cannot be imported.
            AttributeError: If the module has no such attribute.
        """
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load the model class for `model_arch`, returning None on failure.

    Load errors are logged and turned into `None`. Results (including `None`)
    are memoized by `lru_cache`, so a failing load is attempted and logged
    at most once per cached `(model_arch, model)` pair.
    """
    from vllm.platforms import current_platform

    # Runs outside the `try` below, so anything it raises propagates to the
    # caller instead of becoming `None` (and is therefore not cached).
    # NOTE(review): presumably raises for architectures the platform does
    # not support — confirm in `vllm.platforms`.
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Best-effort wrapper around ``model.inspect_model_cls()``.

    Any ``Exception`` raised while inspecting is logged together with
    ``model_arch`` and converted into a ``None`` result. Up to 128 results
    are memoized by ``lru_cache``.
    """
    info = None
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
    return info
@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to registered models.

    Lookup helpers accept a single architecture name or a list of candidate
    names; candidates are tried in order and the first one that can be
    inspected/loaded is used.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return the registered architecture names (a live keys view)."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not of the form
                `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending `model_cls`, not of
            # `model_arch` (which is always `str` at this point).
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        # Warn only once the new entry is known to be valid, so a rejected
        # registration does not claim to overwrite the existing one.
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]) -> NoReturn:
        """Raise a `ValueError` explaining why no candidate was usable.

        Distinguishes candidates that are registered but failed (see the logs)
        from candidates that are not registered at all.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Load `model_arch`'s class; None if unregistered or loading fails."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect `model_arch`; None if unregistered or inspection fails."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, list[str]],
    ) -> list[str]:
        """Turn the input into an ordered list of registered candidates.

        Unregistered names are dropped. If any were dropped, the Transformers
        backend is appended (once) as a last-resort fallback.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # filter out unsupported architectures, preserving the given order
        normalized_arch = [
            arch for arch in architectures if arch in self.models
        ]

        # make sure Transformers backend is put at the last as a fallback,
        # without adding it a second time if it was requested explicitly
        if (len(normalized_arch) != len(architectures)
                and "TransformersForCausalLM" not in normalized_arch):
            normalized_arch.append("TransformersForCausalLM")
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[_ModelInfo, str]:
        """Return `(info, arch)` for the first inspectable candidate.

        Raises:
            ValueError: If no candidate can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[type[nn.Module], str]:
        """Return `(model_cls, arch)` for the first loadable candidate.

        Raises:
            ValueError: If no candidate can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    # The predicates below inspect the first usable candidate (raising
    # `ValueError` like `inspect_model_cls` if there is none) and report one
    # of its `_ModelInfo` flags.

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return not model_info.supports_v0_only

# The global registry, pre-populated with every built-in architecture from
# `_VLLM_MODELS`. Entries are lazy: the module under
# `vllm.model_executor.models` is only imported when its class is loaded.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Call ``fn`` in a separate Python process and return its result.

    ``fn`` and a temporary output path are serialized with ``cloudpickle``
    and piped on stdin to the process started by ``_SUBPROCESS_COMMAND``
    (this module's ``__main__`` entry point). The child writes the pickled
    result to that path, which is read back here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status; the
            message includes the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the error to surface the child's stderr. Decode leniently
            # so undecodable bytes cannot mask the original failure with a
            # UnicodeDecodeError.
            stderr = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr}") from e

        # The file was written by our own child process, not external input.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Child-process entry point used by `_run_in_subprocess`.

    Reads a pickled ``(fn, output_file)`` tuple from stdin, calls ``fn()``,
    and pickles its return value into ``output_file``.
    """
    # Setup plugins (presumably so plugin-registered models are also
    # available in this child process — confirm in `vllm.plugins`).
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # SECURITY: unpickling can execute arbitrary code. The payload is expected
    # to come only from the parent `_run_in_subprocess`; never feed untrusted
    # data to this entry point.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)


if __name__ == "__main__":
    # Executed in the child process launched via `_SUBPROCESS_COMMAND`
    # (`python -m vllm.model_executor.models.registry`).
    _run()