# SPDX-License-Identifier: Apache-2.0
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
6
import importlib
7
import os
8
import pickle
9
10
import subprocess
import sys
11
import tempfile
12
from abc import ABC, abstractmethod
13
from collections.abc import Set
14
15
from dataclasses import dataclass, field
from functools import lru_cache
16
from typing import Callable, Optional, TypeVar, Union
17

18
import cloudpickle
19
20
21
22
import torch.nn as nn

from vllm.logger import init_logger

23
24
25
26
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
                         supports_multimodal, supports_pp,
                         supports_transcription, supports_v0_only)
27
from .interfaces_base import is_text_generation_model
28
29
30

# Module-level logger; used below to warn when a registered architecture is
# overwritten and to log failures while loading or inspecting a model class.
logger = init_logger(__name__)

# yapf: disable

# Each entry maps an architecture name to a tuple of
# (module name relative to `vllm.model_executor.models`, class name in it).
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

# Architectures served as embedding / pooling models. Same value format as
# `_TEXT_GENERATION_MODELS`: (module relative to `vllm.model_executor.models`,
# class name in that module).
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # [Auto-converted (see adapters.py)]
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

# Sequence-classification architectures used as cross-encoders (same value
# format as the other tables in this file).
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
}

# Multimodal architectures (same value format as the other tables in this
# file).
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}

# Draft / speculative-decoding architectures (same value format as the other
# tables in this file).
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EAGLEModel": ("eagle", "EAGLE"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

# Generic Transformers-backend implementation. `_ModelRegistry` appends this
# architecture as a fallback when some requested architectures are unknown.
_TRANSFORMERS_MODELS = {
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
# yapf: enable

# Union of all built-in architecture tables. When the same architecture name
# appears in several tables, the entry from the table merged later wins
# (standard `**` dict-merge semantics).
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}

# Argument vector handed to `subprocess.run()` by `_run_in_subprocess`. It is
# kept in a module-level variable so it can be overridden when needed, e.g.
# when vLLM is packed in the par format and `sys.executable` is not the
# interpreter that should run this module.
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]
@dataclass(frozen=True)
class _ModelInfo:
    """Capability flags computed for a model class.

    Frozen (and therefore hashable) because `_RegisteredModel`, which embeds
    an instance of this class, is used as an `lru_cache` key.
    """

    # `__name__` of the inspected model class.
    architecture: str
    # Every flag below except `is_pooling_model` is the result of the
    # same-named check from `.interfaces` / `.interfaces_base`, evaluated on
    # the model class in `from_model_cls`.
    is_text_generation_model: bool
    # Always True: any model can be converted into a pooling model.
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Build a `_ModelInfo` by running the interface checks on `model`.

        Args:
            model: The model class (not an instance) to inspect.

        Returns:
            A new `_ModelInfo` describing `model`.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=True,  # Can convert any model into a pooling model
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )


class _BaseRegisteredModel(ABC):
    """Interface shared by all entries stored in `_ModelRegistry.models`."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class, importing it if necessary."""
        raise NotImplementedError


@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    The capability flags are computed once, at construction time, so
    `inspect_model_cls` does not need to re-run the interface checks.
    """

    # Capability flags computed from `model_cls` by `from_model_cls`.
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap `model_cls`, computing its `_ModelInfo` eagerly."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Importable module path, e.g. "vllm.model_executor.models.llama".
    module_name: str
    # Name of the model class defined in `module_name`.
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute.

        Raises whatever `importlib.import_module` or `getattr` raises
        (e.g. `ImportError`, `AttributeError`).
        """
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load the class of `model`, returning None if loading fails.

    The current platform verifies `model_arch` first. That check runs outside
    the `try` block, so anything it raises propagates to the caller. Errors
    from `model.load_model_cls()` are logged and turned into `None`.

    Results, including `None` for failures, are memoized per
    `(model_arch, model)` pair, so a failed load is not retried for the same
    registry entry.
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None


@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Return the capability flags of `model`, or None if inspection fails.

    Any `Exception` raised during inspection is logged, tagged with
    `model_arch`, and not propagated. Results, including `None`, are memoized
    per `(model_arch, model)` pair.
    """
    info: Optional[_ModelInfo] = None
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
    return info


@dataclass
class _ModelRegistry:
    """Maps architecture names to registered model entries.

    Supports registering out-of-tree models and resolving or inspecting a
    model class from a list of candidate architecture names.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return the registered architecture names.

        This is a live view of the registry keys, so later registrations are
        reflected in it.
        """
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A {class}`torch.nn.Module` class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an already-registered architecture logs a warning and
        replaces the previous entry. Nothing is changed (and nothing is
        logged) if `model_cls` is rejected.

        Raises:
            TypeError: If `model_arch` is not a string, or if `model_cls` is
                neither a string nor a `torch.nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not of the form
                `<module>:<class>` with both parts non-empty.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        # Validate and build the new entry first, so a rejected `model_cls`
        # does not produce a misleading "will be overwritten" warning.
        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2 or not all(split_str):
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`);
            # `model_arch` has already been validated as a string.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Always raise a ValueError explaining why no architecture worked.

        If any of `architectures` is registered, the failure is reported as
        an inspection error (details are in the logs). Otherwise the
        architectures are reported as unsupported.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Load the class registered for `model_arch`, or return None."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect the class registered for `model_arch`, or return None."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, list[str]],
    ) -> list[str]:
        """Turn `architectures` into an ordered list of candidates to try.

        Registered names are kept in their original order. If any name was
        dropped because it is not registered, `TransformersForCausalLM` is
        appended as a last-resort fallback.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # filter out unsupported architectures
        normalized_arch = list(
            filter(lambda model: model in self.models, architectures))

        # make sure Transformers backend is put at the last as a fallback
        if len(normalized_arch) != len(architectures):
            normalized_arch.append("TransformersForCausalLM")
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[_ModelInfo, str]:
        """Return the `_ModelInfo` of the first inspectable candidate and its
        architecture name.

        Raises:
            ValueError: If no candidate architecture can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[type[nn.Module], str]:
        """Return the first loadable candidate model class and its
        architecture name.

        Raises:
            ValueError: If no candidate architecture can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model is a text generation model."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model can be used as a pooling model."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model supports cross-encoding."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model reports multimodal support."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model reports `supports_pp`."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model reports `has_inner_state`."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model reports `is_attention_free`."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model reports `is_hybrid`."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model reports `has_noops`."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model reports `supports_transcription`."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_transcription

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """True unless the resolved model reports `supports_v0_only`."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return not model_cls.supports_v0_only

# Global registry pre-populated with every built-in architecture. Entries are
# lazy: a model module is imported in this process only when its class is
# loaded. Inspection imports it in a subprocess instead (see
# `_LazyRegisteredModel.inspect_model_cls`).
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a fresh Python process and return its result.

    `fn` and the path of a result file are serialized with `cloudpickle` and
    written to the stdin of `_SUBPROCESS_COMMAND`. The child process (`_run`)
    pickles the return value of `fn` into that file, which is read back here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status. The
            child's stderr is included in the message.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the raised exception to provide more information. Decode
            # leniently so undecodable bytes in stderr cannot replace the real
            # failure with a UnicodeDecodeError.
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr_text}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Entry point of the child process started by `_run_in_subprocess`.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn`, and
    writes its pickled return value to `output_file`.
    """
    # Setup plugins (presumably so that models registered by plugins are
    # also available in this child process — confirm against vllm.plugins).
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: stdin is produced by `_run_in_subprocess` in the parent process;
    # unpickling it is only acceptable because that input is trusted.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)


# Executed as `python -m vllm.model_executor.models.registry` via
# `_SUBPROCESS_COMMAND`.
if __name__ == "__main__":
    _run()