registry.py 25.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import importlib
8
import os
9
import pickle
10
11
import subprocess
import sys
12
import tempfile
13
from abc import ABC, abstractmethod
14
from collections.abc import Set
15
16
from dataclasses import dataclass, field
from functools import lru_cache
17
from typing import Callable, Optional, TypeVar, Union
18

19
import cloudpickle
20
21
22
23
import torch.nn as nn

from vllm.logger import init_logger

24
25
26
27
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
                         supports_multimodal, supports_pp,
                         supports_transcription, supports_v0_only)
28
from .interfaces_base import is_text_generation_model
29
30
31

# Module-level logger used by the registry helpers in this file.
logger = init_logger(__name__)

32
# yapf: disable
33
34
# Maps an architecture name to `(module_name, class_name)`. The module name
# is relative to `vllm.model_executor.models` (see `ModelRegistry` below).
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

# Architectures usable as embedding/pooling models, in the same
# `architecture -> (module_name, class_name)` layout as the table above.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # [Auto-converted (see adapters.py)]
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. It is listed here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

169
170
171
172
173
174
# Sequence-classification architectures registered as cross-encoders, in the
# `architecture -> (module_name, class_name)` layout.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"),  # noqa: E501
}

180
# Multimodal architectures, in the
# `architecture -> (module_name, class_name)` layout.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
227
228

# Draft/speculator architectures, in the
# `architecture -> (module_name, class_name)` layout.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EAGLEModel": ("eagle", "EAGLE"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
238

239
# Transformers-backend entry; `_ModelRegistry._normalize_archs` appends this
# architecture as a fallback when a requested architecture is not registered.
_TRANSFORMERS_MODELS = {
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
242
# yapf: enable
243

244
# Union of every built-in table. When the same architecture appears in more
# than one table, the entry from the later table wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}

253
254
255
256
257
258
259
260
# Argument list handed to subprocess.run() by `_run_in_subprocess`. It is
# kept as a mutable module-level list so it can be altered when needed,
# e.g. when things are packed in par format and sys.executable is not the
# interpreter we actually want to launch.
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]

261

262
263
@dataclass(frozen=True)
class _ModelInfo:
    """Immutable summary of a model class's name and capability flags."""

    # Name of the model class (`model.__name__`).
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Build a `_ModelInfo` by probing `model` with the interface helpers.

        Args:
            model: The model class to inspect.

        Returns:
            A `_ModelInfo` whose flags are the results of the corresponding
            interface predicates applied to `model`.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=True,  # Can convert any model into a pooling model
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
293
294


295
class _BaseRegisteredModel(ABC):
    """Interface shared by every entry stored in the model registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
304
305


306
307
308
309
310
311
312
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability summary computed once, when the entry is created.
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Create an entry for `model_cls`, inspecting it immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` computed at registration time."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Importable module path that defines the model class.
    module_name: str
    # Attribute name of the model class inside that module.
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """Import and inspect the model class in a subprocess."""
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` in this process and return `class_name`."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load the model class for `model_arch`, returning None on failure.

    The platform check runs outside the `try`, so an exception raised by
    `verify_model_arch` propagates to the caller. Only failures of
    `model.load_model_cls()` are logged and turned into `None`. Results,
    including `None`, are memoized by `lru_cache`.
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
360
361


362
363
364
365
366
367
368
369
370
371
372
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Inspect `model`, returning its `_ModelInfo` or None on failure.

    Any exception raised by `model.inspect_model_cls()` is logged with its
    traceback and swallowed. Results, including `None`, are memoized by
    `lru_cache`.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        return None
    return info
373
374


375
376
377
@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to registered model entries."""

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a view of all registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not in the
                `<module>:<class>` format.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        # Re-registration is allowed; the previous entry is replaced.
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`), not
            # `model_arch`, which is already known to be a string here.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Always raise `ValueError` explaining why resolution failed.

        The message distinguishes architectures that are registered but
        failed to load/inspect from architectures that are not registered.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Load the class for `model_arch`; None if unknown or failing."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect `model_arch`; None if unknown or inspection fails."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, list[str]],
    ) -> list[str]:
        """Return the registered subset of `architectures`, in order.

        A single string is treated as a one-element list. If any requested
        architecture is not registered, `"TransformersForCausalLM"` is
        appended as a fallback.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # keep only the architectures that are registered
        normalized_arch = [
            arch for arch in architectures if arch in self.models
        ]

        # make sure Transformers backend is put at the last as a fallback
        if len(normalized_arch) != len(architectures):
            normalized_arch.append("TransformersForCausalLM")
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[_ModelInfo, str]:
        """Return `(model_info, arch)` for the first inspectable architecture.

        Raises:
            ValueError: If none of the architectures can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[type[nn.Module], str]:
        """Return `(model_cls, arch)` for the first loadable architecture.

        Raises:
            ValueError: If none of the architectures can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `is_text_generation_model` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `is_pooling_model` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `supports_cross_encoding` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `supports_multimodal` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `supports_pp` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `has_inner_state` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `is_attention_free` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `is_hybrid` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `has_noops` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `supports_transcription` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return True unless the resolved model is marked `supports_v0_only`."""
        model_info, _ = self.inspect_model_cls(architectures)
        return not model_info.supports_v0_only

571
572

# Global registry pre-populated with every built-in architecture. Entries are
# lazy: the model's module is imported only when the class is loaded.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
585
586
587
588
589
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

590
        # `cloudpickle` allows pickling lambda functions directly
591
        input_bytes = cloudpickle.dumps((fn, output_filepath))
592
593
594

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
595
596
597
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)
598
599
600
601
602
603
604
605
606

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except Exception as e:
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{returned.stderr.decode()}") from e

607
        with open(output_filepath, "rb") as f:
608
609
610
611
612
613
614
615
616
617
618
            return pickle.load(f)


def _run() -> None:
    """Subprocess entry point: run the callable received on stdin.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn()`, and
    writes the pickled result to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: unpickling executes arbitrary code. The payload is expected to be
    # the one written to stdin by `_run_in_subprocess` in the parent process.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        f.write(pickle.dumps(result))
622
623
624


if __name__ == "__main__":
625
    _run()