# SPDX-License-Identifier: Apache-2.0
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
6
import importlib
7
import os
8
import pickle
9
10
import subprocess
import sys
11
import tempfile
12
from abc import ABC, abstractmethod
13
from collections.abc import Set
14
15
from dataclasses import dataclass, field
from functools import lru_cache
16
from typing import Callable, Optional, TypeVar, Union
17

18
import cloudpickle
19
20
21
22
import torch.nn as nn

from vllm.logger import init_logger

23
24
25
26
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
                         supports_multimodal, supports_pp,
                         supports_transcription, supports_v0_only)
27
from .interfaces_base import is_text_generation_model
28
29
30

# Module-level logger; used below for the "already registered" warning and
# for logging failures while loading or inspecting model classes.
logger = init_logger(__name__)

31
# yapf: disable
32
33
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
34
35
36
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
37
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
38
39
40
41
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
Yu Chin Fabian Lim's avatar
Yu Chin Fabian Lim committed
42
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
43
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
44
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
45
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
46
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
47
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
48
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
49
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
50
51
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
52
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
53
54
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
55
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
56
57
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
58
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
59
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
60
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
61
62
63
64
65
66
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
67
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
68
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
69
    "GritLM": ("gritlm", "GritLM"),
Michael Goin's avatar
Michael Goin committed
70
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
71
72
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
73
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
74
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
75
76
77
78
79
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
80
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
81
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
Dhia Eddine Rhaiem's avatar
Dhia Eddine Rhaiem committed
82
    "FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"),
83
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
84
85
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
86
87
88
89
90
91
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
92
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
93
94
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
95
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
96
97
98
99
100
101
102
103
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
Shinichi Hemmi's avatar
Shinichi Hemmi committed
104
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
105
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
106
107
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
108
109
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
110
111
112
113
114
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
115
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
116
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
117
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
118
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
119
120
121
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
122
123
124
}

# Architectures usable as embedding / pooling models, with the same
# `(module, class)` value layout as `_TEXT_GENERATION_MODELS`.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # [Auto-converted (see adapters.py)]
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

165
166
167
168
169
170
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
xsank's avatar
xsank committed
171
172
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
173
174
}

175
_MULTIMODAL_MODELS = {
176
    # [Decoder-only]
177
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
Jennifer Zhao's avatar
Jennifer Zhao committed
178
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
179
180
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
181
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
182
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
183
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
184
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
185
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
186
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
187
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
188
    "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
189
    "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"),  # noqa: E501
190
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
191
192
193
194
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
195
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
196
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
197
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
198
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
199
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
200
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
201
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
202
    "Ovis": ("ovis", "Ovis"),
203
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
204
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
205
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
206
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
207
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
Roger Wang's avatar
Roger Wang committed
208
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
209
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
210
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
211
    "UltravoxModel": ("ultravox", "UltravoxModel"),
212
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
213
    # [Encoder-decoder]
214
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
215
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
216
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
217
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
218
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
219
}
220
221

_SPECULATIVE_DECODING_MODELS = {
222
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
223
    "EAGLEModel": ("eagle", "EAGLE"),
224
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
225
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
226
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
227
228
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
229
}
230

231
_TRANSFORMERS_MODELS = {
232
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
233
}
234
# yapf: enable
235

236
# Union of every built-in architecture group. Dict unpacking means that for
# a key present in several groups, the entry from the later group wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}

245
246
247
248
249
250
251
252
# This variable is used as the args for subprocess.run(). We
# can modify  this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
_SUBPROCESS_COMMAND = [
    sys.executable, "-m", "vllm.model_executor.models.registry"
]

253

254
255
@dataclass(frozen=True)
class _ModelInfo:
256
    architecture: str
257
    is_text_generation_model: bool
258
    is_pooling_model: bool
259
    supports_cross_encoding: bool
260
261
    supports_multimodal: bool
    supports_pp: bool
262
263
    has_inner_state: bool
    is_attention_free: bool
264
    is_hybrid: bool
265
    has_noops: bool
266
    supports_transcription: bool
267
    supports_v0_only: bool
268
269

    @staticmethod
270
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
271
        return _ModelInfo(
272
            architecture=model.__name__,
273
            is_text_generation_model=is_text_generation_model(model),
274
            is_pooling_model=True,  # Can convert any model into a pooling model
275
            supports_cross_encoding=supports_cross_encoding(model),
276
277
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
278
279
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
280
            is_hybrid=is_hybrid(model),
281
282
            supports_transcription=supports_transcription(model),
            supports_v0_only=supports_v0_only(model),
283
            has_noops=has_noops(model),
284
        )
285
286


287
class _BaseRegisteredModel(ABC):
    """Interface of a single entry in the model registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
296
297


298
299
300
301
302
303
304
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability flags, computed eagerly by `from_model_cls`.
    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap an already-imported class, inspecting it immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the flags computed at construction time."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the wrapped class (no import needed)."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Importable module path, e.g. "vllm.model_executor.models.llama".
    module_name: str
    # Name of the model class inside `module_name`.
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """Import and inspect the class in a subprocess; return its flags."""
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` in this process and return `class_name`."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load the model class registered for `model_arch`.

    The platform check (`current_platform.verify_model_arch`) runs outside
    the `try`, so anything it raises propagates to the caller. Errors raised
    while loading the class are logged with a traceback and reported as
    ``None``. Return values (including ``None``) are memoized per
    ``(model_arch, model)`` pair.
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
352
353


354
355
356
357
358
359
360
361
362
363
364
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Return the capability flags of `model`, or ``None`` on failure.

    Any exception raised during inspection is logged (with traceback)
    instead of being propagated. Results are memoized per
    ``(model_arch, model)`` pair.
    """
    info: Optional[_ModelInfo] = None
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
    return info
365
366


367
368
369
@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to registered model classes.

    Query methods accept a single architecture name or a list of candidates.
    Candidates are tried in order, and the first one that can be
    inspected or loaded is used.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of the names of all registered architectures."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A {class}`torch.nn.Module` class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        An existing registration for `model_arch` is overwritten, and a
        warning is logged when that happens.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor a subclass of `nn.Module`.
            ValueError: If `model_cls` is a string that does not contain
                exactly one `:` separator.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`);
            # `model_arch` is already known to be a string at this point.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        # Only warn once the new entry is valid, so a rejected call does not
        # claim that an existing registration "will be overwritten".
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a `ValueError` explaining why `architectures` failed.

        Never returns. If any of `architectures` is registered, it must have
        failed to be inspected or loaded (details are in the logs).
        Otherwise none of them is supported.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Load `model_arch`'s class; ``None`` if unregistered or failing."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect `model_arch`; ``None`` if unregistered or failing."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, list[str]],
    ) -> list[str]:
        """Turn `architectures` into the ordered list of candidates to try.

        Unregistered names are dropped, and the rest keep their order. If
        anything was dropped, ``"TransformersForCausalLM"`` is appended as a
        last-resort fallback.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # keep only registered architectures, preserving their order
        normalized_arch = [
            arch for arch in architectures if arch in self.models
        ]

        # make sure Transformers backend is put at the last as a fallback
        if len(normalized_arch) != len(architectures):
            normalized_arch.append("TransformersForCausalLM")
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[_ModelInfo, str]:
        """Return the flags of the first inspectable candidate and its name.

        Raises:
            ValueError: If no candidate can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[type[nn.Module], str]:
        """Return the first loadable candidate class and its name.

        Raises:
            ValueError: If no candidate can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model is a text generation model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model can be used as a pooling model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model supports cross-encoding."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model supports multimodal inputs."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model supports pipeline parallelism."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model reports `has_inner_state`."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model reports `is_attention_free`."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model reports `is_hybrid`."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model reports `has_noops`."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model supports transcription."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """True unless the resolved model is flagged as V0-only."""
        model_info, _ = self.inspect_model_cls(architectures)
        return not model_info.supports_v0_only

563
564

# Global registry. Every built-in architecture is registered lazily, so no
# model module is imported until it is actually loaded or inspected.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a fresh Python process and return its unpickled result.

    `fn` and the path of a temporary output file are cloudpickled and sent to
    the child (started with `_SUBPROCESS_COMMAND`) on stdin. The child is
    expected to write the pickled result to that file.

    Raises:
        RuntimeError: If the child exits with a non-zero status; the child's
            stderr is included in the message.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently: undecodable bytes in stderr must not replace
            # the real failure with a UnicodeDecodeError.
            stderr = returned.stderr.decode(errors="replace")
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Child-process entry point used by `_run_in_subprocess`.

    Loads the general plugins, reads a pickled ``(fn, output_file)`` pair
    from stdin, calls ``fn()`` and pickles the result into ``output_file``.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE(review): the payload is produced by the parent process and is
    # presumably trusted; `pickle.loads` must never see untrusted input.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)
614
615
616


if __name__ == "__main__":
    # Reached when `_SUBPROCESS_COMMAND` runs this module with `-m`.
    _run()