# SPDX-License-Identifier: Apache-2.0
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
import importlib
import os
import pickle
import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import lru_cache
from typing import (AbstractSet, Callable, Dict, List, Optional, Tuple, Type,
                    TypeVar, Union)

import cloudpickle
import torch.nn as nn

from vllm.logger import init_logger
from vllm.utils import is_in_doc_build

from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
                         supports_multimodal, supports_pp,
                         supports_transcription, supports_v0_only)
from .interfaces_base import is_text_generation_model

# Module-level logger shared by every helper in this file.
logger = init_logger(__name__)

32
# yapf: disable
33
34
# Maps an architecture name to a `(module, class)` pair: the module name is
# relative to `vllm.model_executor.models` and the class name is looked up
# inside that module when the model is lazily imported.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),  # noqa: E501
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

# Architectures usable as embedding/pooling models, mapped to the
# `(module, class)` pair that implements them.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert", "GteEmbeddingModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "NomicBertModel": ("bert", "NomicBertEmbeddingModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # [Auto-converted (see adapters.py)]
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

161
162
163
164
165
166
# Sequence-classification architectures used for cross-encoding, mapped to
# the `(module, class)` pair that implements them.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
}

171
# Multimodal architectures, mapped to the `(module, class)` pair that
# implements them.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
213
214
215

# Draft/speculator architectures, mapped to the `(module, class)` pair that
# implements them.
_SPECULATIVE_DECODING_MODELS = {
    "EAGLEModel": ("eagle", "EAGLE"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
222

223
# The Transformers-backend architecture, mapped to its `(module, class)` pair.
_TRANSFORMERS_MODELS = {
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
226
# yapf: enable
227

228
# Union of every built-in table. When the same architecture name appears in
# more than one table, the entry from the table listed last wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}

237
238
239
240
241
242
243
244
# Argument list passed to subprocess.run() when a callable has to be
# executed in a separate interpreter. It is kept as a module-level variable
# so that it can be modified when needed, e.g. when everything is packed
# together in par format and sys.executable might not be the target we
# want to run.
_SUBPROCESS_COMMAND = [
    sys.executable, "-m", "vllm.model_executor.models.registry"
]

245

246
247
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable summary of the capabilities advertised by a model class.

    Every flag is computed once by :meth:`from_model_cls`.
    """
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
        """
        Build a :class:`_ModelInfo` by probing ``model`` with the interface
        predicates; the architecture is taken from the class name.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=True,  # Can convert any model into a pooling model
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            has_noops=has_noops(model),
            supports_transcription=supports_transcription(model),
            supports_v0_only=supports_v0_only(model),
        )
277
278


279
class _BaseRegisteredModel(ABC):
    """Abstract interface shared by every entry stored in the registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability summary of the registered model."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> Type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
288
289


290
291
292
293
294
295
296
297
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability summary computed once, when the entry is created
    interfaces: _ModelInfo
    # The already-imported model class
    model_cls: Type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: Type[nn.Module]) -> "_RegisteredModel":
        """Wrap ``model_cls`` together with its freshly computed info."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the stored capability summary."""
        return self.interfaces

    def load_model_cls(self) -> Type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Registry entry for a model whose module has not been imported in the
    main process; it is located by module path and class name.
    """
    module_name: str
    class_name: str

    def inspect_model_cls(self) -> _ModelInfo:
        # The inspection is delegated to another process so that importing
        # the model does not initialize CUDA in this one.
        def _inspect():
            return _ModelInfo.from_model_cls(self.load_model_cls())

        return _run_in_subprocess(_inspect)

    def load_model_cls(self) -> Type[nn.Module]:
        # Import the module on demand, then pick the class out of it.
        module = importlib.import_module(self.module_name)
        model_cls = getattr(module, self.class_name)
        return model_cls


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
    """
    Load the class behind ``model``, or return ``None`` if loading fails.

    The platform check runs outside the ``try`` block, so an error raised by
    ``verify_model_arch`` propagates to the caller. Errors raised while
    loading are logged with their traceback. Results, including ``None``,
    are memoized by ``lru_cache``.
    """
    # Imported here rather than at module import time.
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
344
345


346
347
348
349
350
351
352
353
354
355
356
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """
    Inspect ``model`` and return its info, or ``None`` when inspection
    raises; the failure is logged with its traceback.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        info = None
    return info
357
358


359
360
361
362
@dataclass
class _ModelRegistry:
    """
    Registry that resolves architecture names to model classes and to the
    capability summaries (:class:`_ModelInfo`) of those classes.
    """
    # Keyed by model_arch
    models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> AbstractSet[str]:
        """Return a view of every registered architecture name."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[Type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        :code:`model_cls` can be either:

        - A :class:`torch.nn.Module` class directly referencing the model.
        - A string in the format :code:`<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Raises :class:`TypeError` when either argument has the wrong type and
        :class:`ValueError` when a string is not in ``<module>:<class>`` form.
        An existing registration is overwritten after a warning is logged.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and (is_in_doc_build() or issubclass(
                model_cls, nn.Module)):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`), not
            # of `model_arch`, which has already been validated above.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: List[str]):
        """
        Always raise :class:`ValueError` for ``architectures``.

        The message distinguishes architectures that are registered but could
        not be inspected/loaded from architectures that are not registered.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[Type[nn.Module]]:
        """Load the class for ``model_arch``; ``None`` if unknown or failing."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect ``model_arch``; ``None`` if unknown or failing."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, List[str]],
    ) -> List[str]:
        """
        Return the registered subset of ``architectures`` in their original
        order, with the Transformers fallback appended when at least one
        requested architecture is not registered.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # filter out support architectures
        normalized_arch = [
            arch for arch in architectures if arch in self.models
        ]

        # make sure Transformers backend is put at the last as a fallback
        if len(normalized_arch) != len(architectures):
            normalized_arch.append("TransformersForCausalLM")
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[_ModelInfo, str]:
        """
        Return ``(info, arch)`` for the first architecture that can be
        inspected; raise :class:`ValueError` when none can.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[Type[nn.Module], str]:
        """
        Return ``(model_cls, arch)`` for the first architecture that can be
        loaded; raise :class:`ValueError` when none can.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    # Each predicate below resolves `architectures` via `inspect_model_cls`
    # and reads a single flag of the resulting `_ModelInfo`.

    def is_text_generation_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription

    def is_v1_compatible(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        # V1-compatible means the model is not flagged as V0-only.
        model_info, _ = self.inspect_model_cls(architectures)
        return not model_info.supports_v0_only

556
557

# Global registry instance. Every built-in architecture is registered as a
# lazy entry, so its module is imported only when the model is needed.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Execute ``fn`` in a fresh interpreter and return its result.

    ``fn`` and an output path are serialized with ``cloudpickle`` and sent to
    the child on stdin; the child writes the pickled result to that path,
    which is read back here.

    Raises:
        RuntimeError: if the child exits with a non-zero status; the message
            carries the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently so that undecodable bytes in stderr cannot
            # replace the real failure with a UnicodeDecodeError.
            stderr_text = returned.stderr.decode(errors="replace")
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr_text}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """
    Child-process entry point: read a pickled ``(fn, output_file)`` pair
    from stdin, call ``fn`` and write the pickled result to ``output_file``.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: unpickling executes arbitrary code. This is only acceptable
    # because stdin is expected to be fed by the parent process that spawned
    # this interpreter; never pipe untrusted data into this entry point.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)
607
608
609


# Executed when this module is launched as a script, which is how
# _SUBPROCESS_COMMAND starts the child interpreter.
if __name__ == "__main__":
    _run()