# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
import importlib
8
import os
9
import pickle
10
11
import subprocess
import sys
12
import tempfile
13
from abc import ABC, abstractmethod
14
from collections.abc import Set
15
16
from dataclasses import dataclass, field
from functools import lru_cache
17
from typing import Callable, Optional, TypeVar, Union
18

19
import cloudpickle
20
21
22
23
import torch.nn as nn

from vllm.logger import init_logger

24
25
26
27
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
                         supports_multimodal, supports_pp,
                         supports_transcription, supports_v0_only)
28
from .interfaces_base import is_text_generation_model
logger = init_logger(__name__)

# yapf: disable
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
35
36
37
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
38
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
39
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
40
41
42
43
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
Yu Chin Fabian Lim's avatar
Yu Chin Fabian Lim committed
44
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
45
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
46
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
47
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
48
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
49
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
50
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
51
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
52
53
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
54
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
55
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
56
57
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
58
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
59
60
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
61
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
Robert Shaw's avatar
Robert Shaw committed
62
63
    #TODO(ywang96): Support multimodal gemma3n
    "Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"),    # noqa: E501
64
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
65
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
66
67
68
69
70
71
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
72
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
73
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
74
    "GritLM": ("gritlm", "GritLM"),
Michael Goin's avatar
Michael Goin committed
75
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
76
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"),
77
78
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
79
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
80
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
81
82
83
84
85
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
86
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
87
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
Dhia Eddine Rhaiem's avatar
Dhia Eddine Rhaiem committed
88
    "FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"),
89
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
90
91
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
92
93
94
95
96
97
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
98
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
99
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
Luis Vega's avatar
Luis Vega committed
100
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
101
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
102
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
103
104
105
106
107
108
109
110
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
Shinichi Hemmi's avatar
Shinichi Hemmi committed
111
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
112
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
113
114
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
115
116
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
117
118
119
120
121
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
122
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
123
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
124
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
125
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
126
127
128
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
129
130
131
}

_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all:
        # every text-generation entry backed by `LlamaForCausalLM`.
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # [Auto-converted (see adapters.py)]
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

173
174
175
176
177
178
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
xsank's avatar
xsank committed
179
180
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
181
    "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501
182
183
}

184
_MULTIMODAL_MODELS = {
185
    # [Decoder-only]
186
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
Jennifer Zhao's avatar
Jennifer Zhao committed
187
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
188
189
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
190
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
191
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
192
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
193
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
194
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
195
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
196
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
197
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
198
    "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
199
    "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"),  # noqa: E501
200
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
201
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
202
203
204
205
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
206
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
207
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
208
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
209
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
210
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
211
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
212
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
213
    "Ovis": ("ovis", "Ovis"),
214
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
215
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
216
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
217
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
218
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
Roger Wang's avatar
Roger Wang committed
219
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
220
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
221
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
222
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
223
    "UltravoxModel": ("ultravox", "UltravoxModel"),
224
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
汪志鹏's avatar
汪志鹏 committed
225
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
226
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
227
    # [Encoder-decoder]
228
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
229
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
230
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
231
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
232
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
233
}
234
235

_SPECULATIVE_DECODING_MODELS = {
236
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
237
    "EAGLEModel": ("eagle", "EAGLE"),
238
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
239
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
240
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
241
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
242
243
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
244
}
245

246
_TRANSFORMERS_MODELS = {
247
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
248
}
249
# yapf: enable
250

# Union of all built-in tables. On duplicate keys the later table wins
# (standard dict-unpacking semantics).
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}

260
261
262
263
264
265
266
267
# This variable is used as the args for subprocess.run(). We
# can modify  this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
_SUBPROCESS_COMMAND = [
    sys.executable, "-m", "vllm.model_executor.models.registry"
]

268

269
270
@dataclass(frozen=True)
class _ModelInfo:
271
    architecture: str
272
    is_text_generation_model: bool
273
    is_pooling_model: bool
274
    supports_cross_encoding: bool
275
276
    supports_multimodal: bool
    supports_pp: bool
277
278
    has_inner_state: bool
    is_attention_free: bool
279
    is_hybrid: bool
280
    has_noops: bool
281
    supports_transcription: bool
282
    supports_v0_only: bool
283
284

    @staticmethod
285
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
286
        return _ModelInfo(
287
            architecture=model.__name__,
288
            is_text_generation_model=is_text_generation_model(model),
289
            is_pooling_model=True,  # Can convert any model into a pooling model
290
            supports_cross_encoding=supports_cross_encoding(model),
291
292
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
293
294
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
295
            is_hybrid=is_hybrid(model),
296
297
            supports_transcription=supports_transcription(model),
            supports_v0_only=supports_v0_only(model),
298
            has_noops=has_noops(model),
299
        )
300
301


class _BaseRegisteredModel(ABC):
    """Interface for a registry entry that can describe and load a model."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability summary of the model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class itself."""
        raise NotImplementedError


@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability summary computed once, at construction time.
    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Build an entry, inspecting `model_cls` in the current process."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability summary captured at construction time."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the already-imported model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Fully qualified module to import, and the attribute name of the class.
    module_name: str
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """Import and inspect the model class inside a fresh subprocess."""
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute.

        Raises whatever `importlib.import_module` / `getattr` raise
        (e.g. ImportError, AttributeError).
        """
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load the model class behind `model`, or return None on failure.

    The platform check runs first and is NOT guarded: anything raised by
    `current_platform.verify_model_arch` propagates to the caller. Errors
    raised while loading the class itself are logged and turned into None.
    Return values (including None) are memoized per (model_arch, model).
    """
    # Imported locally rather than at module level (presumably to avoid an
    # import cycle or early platform detection -- TODO confirm).
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None


@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Inspect `model`, returning None (after logging) if that fails.

    Return values, including None, are memoized per (model_arch, model).
    """
    info = None
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
    return info


@dataclass
class _ModelRegistry:
    """Maps architecture names to registry entries that can load them.

    `register_model` stores either a `_RegisteredModel` (class already
    imported) or a `_LazyRegisteredModel` (`<module>:<class>` imported on
    demand). The query helpers accept one architecture name or a list of
    candidates and answer for the first candidate that can be inspected.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of all registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        An existing registration for `model_arch` is overwritten (with a
        warning).

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not of the form
                `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the offending `model_cls` type; `model_arch` is already
            # known to be a `str` here.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        # Warn only once the new entry is known to be valid, so a rejected
        # registration does not claim to overwrite the existing one.
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a ValueError explaining why none of `architectures` worked.

        Distinguishes "registered but failed" from "not registered at all".
        Never returns normally.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Load the class for `model_arch`; None if unregistered or failed."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect `model_arch`; None if unregistered or inspection failed."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, list[str]],
    ) -> list[str]:
        """Keep only registered architectures, preserving their order.

        If any requested architecture is not registered,
        `TransformersForCausalLM` is appended (at most once) as a fallback.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # filter out unsupported architectures
        normalized_arch = [
            arch for arch in architectures if arch in self.models
        ]

        # make sure Transformers backend is put at the last as a fallback;
        # skip it if it is already a candidate, since retrying it is useless.
        if (len(normalized_arch) != len(architectures)
                and "TransformersForCausalLM" not in normalized_arch):
            normalized_arch.append("TransformersForCausalLM")
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[_ModelInfo, str]:
        """Return the info of the first inspectable candidate and its name.

        Raises:
            ValueError: If no candidate can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[type[nn.Module], str]:
        """Return the first loadable candidate's class and its name.

        Raises:
            ValueError: If no candidate can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model is a text generation model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model can be used as a pooling model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model supports cross encoding."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model supports multimodal inputs."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model supports pipeline parallelism."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model has inner state."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model is attention-free."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model is a hybrid model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model reports `has_noops`."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved model supports transcription."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """True unless the resolved model is marked as V0-only."""
        model_info, _ = self.inspect_model_cls(architectures)
        return not model_info.supports_v0_only

# Default registry: every built-in architecture, registered lazily so that
# `vllm.model_executor.models.<module>` is only imported when the entry is
# loaded or inspected.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a fresh Python process and return its result.

    `fn` and an output path are serialized with `cloudpickle` and piped on
    stdin to `_SUBPROCESS_COMMAND`. The child (`_run`) writes the pickled
    result to that path, which is read back here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status; the
            message contains the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file  # noqa: E501
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently: stderr may contain non-UTF-8 bytes, and a
            # UnicodeDecodeError here would hide the actual failure.
            stderr = returned.stderr.decode(errors="replace")
            # wrap raised exception to provide more information
            raise RuntimeError(
                f"Error raised in subprocess:\n{stderr}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Subprocess entry point used by `_run_in_subprocess`.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn()` and
    writes the pickled result to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: stdin is produced by the parent process with `cloudpickle`.
    # Unpickling can execute arbitrary code, so never feed untrusted data here.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)


if __name__ == "__main__":
    # Executed as `python -m vllm.model_executor.models.registry` by
    # `_run_in_subprocess` (see `_SUBPROCESS_COMMAND`).
    _run()