# SPDX-License-Identifier: Apache-2.0
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
6
import importlib
7
import os
8
import pickle
9
10
import subprocess
import sys
11
import tempfile
12
from abc import ABC, abstractmethod
13
from collections.abc import Set
14
15
from dataclasses import dataclass, field
from functools import lru_cache
16
from typing import Callable, Optional, TypeVar, Union
17

18
import cloudpickle
19
20
21
22
import torch.nn as nn

from vllm.logger import init_logger

23
24
25
26
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
                         supports_multimodal, supports_pp,
                         supports_transcription, supports_v0_only)
27
from .interfaces_base import is_text_generation_model
28
29
30

# Module-level logger: reports architectures being re-registered and
# failures while loading or inspecting model architectures.
logger = init_logger(__name__)

31
# yapf: disable

# Each registry table below maps an architecture name to a pair
# (module name relative to `vllm.model_executor.models`, class name).
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

# Pooling/embedding architectures. Entries that reuse a text-generation
# class (e.g. "LlamaModel") point at the same implementation module.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # [Auto-converted (see adapters.py)]
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

164
165
166
167
168
169
# Sequence-classification architectures used for cross-encoder scoring.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
}

174
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
219
220

# Draft / auxiliary models used for speculative decoding.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EAGLEModel": ("eagle", "EAGLE"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
229

230
# Generic Transformers-backend fallback; `_ModelRegistry._normalize_archs`
# appends this architecture when a requested one is not registered.
_TRANSFORMERS_MODELS = {
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
# yapf: enable
234

235
# All built-in architectures. On duplicate keys, dict unpacking keeps the
# value from the table listed last.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}

244
245
246
247
248
249
250
251
# Argument list handed to `subprocess.run()` when inspecting a model in a
# child process. It is a module-level variable so it can be overridden; e.g.
# when the code is packaged as a `par` archive, `sys.executable` may not be
# the interpreter we actually want to launch.
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]

252

253
254
@dataclass(frozen=True)
class _ModelInfo:
    """Capability flags of a single model class.

    Every boolean field except ``is_pooling_model`` is computed in
    ``from_model_cls`` by the same-named helper from ``.interfaces`` /
    ``.interfaces_base``. The dataclass is frozen, hence hashable, because
    it is stored inside ``_RegisteredModel``, which is used as an
    ``lru_cache`` key.
    """

    # `__name__` of the inspected model class.
    architecture: str
    is_text_generation_model: bool
    # Always True: any model can be converted into a pooling model.
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Probe `model` with the interface helpers and collect the results."""
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=True,  # Can convert any model into a pooling model
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
284
285


286
class _BaseRegisteredModel(ABC):
    """Interface of a registry entry, whether the model class has already
    been imported or is imported lazily on demand."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
295
296


297
298
299
300
301
302
303
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability flags, computed once when the entry is created.
    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap an imported model class, inspecting it in this process."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    The class is located by ``module_name`` (an importable dotted path) and
    ``class_name`` (an attribute of that module).
    """
    module_name: str
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import ``module_name`` and return its ``class_name`` attribute.

        Raises whatever ``importlib.import_module`` / ``getattr`` raise
        (e.g. ``ImportError``, ``AttributeError``).
        """
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load the model class for `model_arch`, or return None on failure.

    The platform check runs outside the ``try`` block, so an exception it
    raises propagates to the caller. Errors while loading the class itself
    are logged and turned into ``None``. Successful and ``None`` results are
    memoized per ``(model_arch, model)`` pair.
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
351
352


353
354
355
356
357
358
359
360
361
362
363
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Inspect `model`; on any error, log it and yield ``None`` instead.

    Memoized per ``(model_arch, model)`` pair, so a failed inspection is
    remembered as ``None`` as well.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        info = None
    return info
364
365


366
367
368
@dataclass
class _ModelRegistry:
    """Mapping from architecture name to registry entry, plus the queries
    used to resolve and inspect architectures."""

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of the registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A {class}`torch.nn.Module` class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        An already-registered `model_arch` is overwritten (with a warning).

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If a string `model_cls` is not `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument, i.e. `model_cls`.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Always raise ValueError, distinguishing "registered but failed"
        from "not registered at all"."""
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, list[str]],
    ) -> list[str]:
        """Keep only registered architectures (in order), appending the
        Transformers fallback if any requested architecture was dropped."""
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # keep only the architectures that are registered
        normalized_arch = [
            arch for arch in architectures if arch in self.models
        ]

        # make sure Transformers backend is put at the last as a fallback
        if len(normalized_arch) != len(architectures):
            normalized_arch.append("TransformersForCausalLM")
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[_ModelInfo, str]:
        """Return the flags of the first inspectable architecture and its name.

        Raises:
            ValueError: If no candidate architecture can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[type[nn.Module], str]:
        """Return the first loadable model class and its architecture name.

        Raises:
            ValueError: If no candidate architecture can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return not model_info.supports_v0_only

562
563

# Global registry. Built-in architectures are registered lazily, so no model
# implementation module is imported until its class is actually requested.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a fresh Python interpreter and return its result.

    `fn` and the path of an output file are serialized with `cloudpickle`
    and fed to the child's stdin; the child runs this module as a script
    (see `_SUBPROCESS_COMMAND`), calls `fn` and pickles the result into that
    file, which is then read back here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status; the
            child's stderr is included in the message.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently: a strict decode of non-UTF-8 stderr would
            # raise UnicodeDecodeError here and hide the real failure.
            stderr = returned.stderr.decode(errors="replace")
            # wrap raised exception to provide more information
            raise RuntimeError(
                f"Error raised in subprocess:\n{stderr}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Child-process entry point used by `_run_in_subprocess`.

    Reads a pickled ``(fn, output_file)`` pair from stdin, calls ``fn`` and
    pickles its return value into ``output_file``.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: stdin is expected to come only from the parent vLLM process;
    # unpickling untrusted input would be unsafe.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)
613
614
615


# Executed when `_run_in_subprocess` launches this module with `-m`.
if __name__ == "__main__":
    _run()