# SPDX-License-Identifier: Apache-2.0
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""

import importlib
import os
import pickle
import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
from collections.abc import Set
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Callable, Optional, TypeVar, Union

import cloudpickle
import torch.nn as nn

from vllm.logger import init_logger

from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
                         supports_multimodal, supports_pp,
                         supports_transcription, supports_v0_only)
from .interfaces_base import is_text_generation_model

# Module-level logger; used by the registry helpers below.
logger = init_logger(__name__)

# yapf: disable
# Maps a HuggingFace `architectures` entry to (module under
# `vllm.model_executor.models`, class name inside that module).
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

# Pooling/embedding architectures, in the same
# `arch -> (module, class)` format as `_TEXT_GENERATION_MODELS`.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # [Auto-converted (see adapters.py)]
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

# Sequence-classification (cross-encoder / reranker) architectures.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
}

# Multimodal (vision / audio / video) architectures.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    # NOTE(review): Llama4 and SkyworkR1V were previously listed under
    # [Encoder-decoder]; they are decoder-only and belong here.
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}

# Draft / speculator architectures used for speculative decoding.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EAGLEModel": ("eagle", "EAGLE"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

# Generic Transformers-backend implementation. `_ModelRegistry` appends this
# architecture as a fallback when a requested architecture is not registered.
_TRANSFORMERS_MODELS = {
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
# yapf: enable

# Union of all built-in architectures. For keys present in several groups
# (e.g. "DeciLMForCausalLM" is both a text-generation and an embedding
# entry), the value from the later group in this literal wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}

# This variable is used as the args for subprocess.run(). We
# can modify this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
# Running this module with `-m` executes `_run()` via the `__main__` guard
# at the bottom of this file.
_SUBPROCESS_COMMAND = [
    sys.executable, "-m", "vllm.model_executor.models.registry"
]


@dataclass(frozen=True)
class _ModelInfo:
    """Capability flags describing a model class.

    The dataclass is frozen, so instances are immutable and hashable. Hashing
    matters because `_RegisteredModel` embeds a `_ModelInfo` and is used as an
    ``lru_cache`` key. Instances are also pickled back from the inspection
    subprocess.
    """

    # Name of the inspected class (``model.__name__``).
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Compute the capability flags of ``model`` using the interface
        predicates imported from ``.interfaces`` / ``.interfaces_base``.

        This imports nothing by itself, but ``model`` must already be loaded.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=True,  # Can convert any model into a pooling model
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )


class _BaseRegisteredModel(ABC):
    """Interface of an entry stored in ``_ModelRegistry.models``."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the underlying model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the underlying model class, importing it if necessary."""
        raise NotImplementedError


@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability flags, computed eagerly when the entry is created.
    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Build an entry for ``model_cls``, inspecting it in this process."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the flags computed at construction time."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the already-imported class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Importable module path, e.g. "vllm.model_executor.models.llama".
    module_name: str
    # Name of the model class attribute inside `module_name`.
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """Import and inspect the class in a subprocess; return its flags."""
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import ``module_name`` in this process and return ``class_name``.

        Raises whatever ``importlib.import_module`` / ``getattr`` raise (e.g.
        ``ImportError``, ``AttributeError``).
        """
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load the model class registered as ``model_arch``.

    The platform check runs outside the ``try``, so errors it raises
    propagate to the caller. Any error raised while loading the class is
    logged and turned into a ``None`` return. Returned values (including
    ``None``) are memoized per ``(model_arch, model)`` pair.
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None


@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Return the capability flags of ``model``, or ``None`` on failure.

    Errors are logged rather than raised. Returned values (including
    ``None``) are memoized per ``(model_arch, model)`` pair, so a lazily
    registered model is inspected in a subprocess at most once per cache
    entry.
    """
    try:
        return model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        return None


@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to registered model entries.

    The query helpers accept either one architecture name or a list of
    candidates. They answer for the first candidate that can be inspected
    (or loaded) successfully.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of all registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that does not consist of
                two non-empty parts separated by a single `:`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        # Validate and build the new entry before touching the registry, so
        # an invalid `model_cls` neither logs a misleading "overwritten"
        # warning nor modifies `self.models`.
        model: _BaseRegisteredModel
        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            # Both the module path and the class name must be non-empty,
            # otherwise the lazy import can never succeed.
            if len(split_str) != 2 or not all(split_str):
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument, i.e. `model_cls`.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Always raise ``ValueError`` explaining why ``architectures`` failed.

        Registered architectures that still failed were logged by the
        ``_try_*`` helpers. Unregistered ones get the list of supported names.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Load ``model_arch``, or return ``None`` if unregistered/failed."""
        if model_arch not in self.models:
            return None

        # Resolves to the module-level, lru-cached helper of the same name.
        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect ``model_arch``, or return ``None`` if unregistered/failed."""
        if model_arch not in self.models:
            return None

        # Resolves to the module-level, lru-cached helper of the same name.
        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, list[str]],
    ) -> list[str]:
        """Return the registered subset of ``architectures``, in order.

        If any requested architecture is not registered, the Transformers
        backend ``TransformersForCausalLM`` is added (once) as a fallback.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # keep only the architectures that are registered
        normalized_arch = [
            arch for arch in architectures if arch in self.models
        ]

        # make sure Transformers backend is put at the last as a fallback
        # (without duplicating it if the caller already requested it)
        if (len(normalized_arch) != len(architectures)
                and "TransformersForCausalLM" not in normalized_arch):
            normalized_arch.append("TransformersForCausalLM")
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[_ModelInfo, str]:
        """Return ``(info, arch)`` for the first inspectable architecture.

        Raises:
            ValueError: If no candidate architecture can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[type[nn.Module], str]:
        """Return ``(model_cls, arch)`` for the first loadable architecture.

        Raises:
            ValueError: If no candidate architecture can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture is a text-generation model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture can be used for pooling."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture supports cross-encoding."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture supports multimodal input."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture supports pipeline parallelism."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture has inner state."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture is attention-free."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture is a hybrid model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture has no-op layers (``has_noops``)."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture supports transcription."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture is not restricted to V0."""
        model_info, _ = self.inspect_model_cls(architectures)
        return not model_info.supports_v0_only


# The global registry, pre-populated with every built-in architecture as a
# `_LazyRegisteredModel` that points into `vllm.model_executor.models`.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run ``fn`` in a fresh Python process and return its result.

    ``fn`` and an output path are serialized with ``cloudpickle`` and passed
    on stdin to ``_SUBPROCESS_COMMAND``. The child (``_run``) pickles the
    result into that path, and the result is read back here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status. The
            message contains the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # wrap raised exception to provide more information; decode
            # leniently so stray non-UTF-8 bytes cannot mask the real error
            stderr = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n{stderr}") from e

        # The file was written by our own child process (see `_run`).
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Entry point of the inspection subprocess started by
    ``_run_in_subprocess``.

    Reads a pickled ``(fn, output_file)`` pair from stdin, calls ``fn()`` and
    pickles the result into ``output_file``.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: unpickling stdin executes arbitrary code; this is only acceptable
    # because stdin is expected to come from `_run_in_subprocess` in the
    # parent process. Never feed untrusted data to this entry point.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)


if __name__ == "__main__":
    _run()