# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
import importlib
import os
import pickle
import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
from collections.abc import Set
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Callable, NoReturn, Optional, TypeVar, Union

import cloudpickle
import torch.nn as nn

from vllm.logger import init_logger

from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
                         supports_multimodal, supports_pp,
                         supports_transcription, supports_v0_only)
from .interfaces_base import is_text_generation_model

# Module-level logger: used for registration/normalization warnings and for
# reporting exceptions raised while loading or inspecting a model class.
logger = init_logger(__name__)

# yapf: disable

# Maps a model architecture name to a ``(module, class)`` pair: the module is
# imported from ``vllm.model_executor.models.<module>`` and ``<class>`` is the
# vLLM implementation used for that architecture. Each key must appear only
# once; a repeated key in a dict literal silently overrides the earlier one.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5_ForCausalLM": ("ernie45", "Ernie4_5_ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FM9GForCausalLM": ("fm9g", "FM9GForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    # TODO(ywang96): Support multimodal gemma3n
    "Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"),    # noqa: E501
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

# Architectures served as embedding / pooling models, mapped to the same
# ``(module, class)`` pairs as the other tables in this file.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        # (every text-generation entry that is backed by LlamaForCausalLM).
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

# Architectures that score (query, document) pairs, i.e. cross-encoders.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    # [Auto-converted (see adapters.py)]
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"),  # noqa: E501
    "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"),  # noqa: E501
}

# Architectures that accept non-text inputs (images, video, audio), mapped to
# ``(module, class)`` pairs like the other tables in this file.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}

# Draft / proposer architectures used for speculative decoding.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EAGLEModel": ("eagle", "EAGLE"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

# Generic Transformers-backend implementation. `_ModelRegistry` appends this
# architecture as a last-resort fallback when a requested one is unsupported.
_TRANSFORMERS_MODELS = {
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
# yapf: enable

# Every built-in architecture, merged from the per-task tables. When the same
# architecture appears in several tables, the table unpacked later wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}

# Argument vector passed to `subprocess.run()` when a model class must be
# inspected in a separate process: it re-runs this module with the current
# interpreter. It lives in a module-level variable so it can be overridden,
# e.g. when vLLM is packed into a single archive (such as the PAR format) and
# `sys.executable` is not the program that should be launched.
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]

@dataclass(frozen=True)
class _ModelInfo:
    """Capability flags of a model class, as reported by the interface checks.

    Instances are immutable so they can be memoized by the ``lru_cache``-d
    inspection helper and pickled back from the subprocess used to inspect
    lazily registered models.
    """

    # `__name__` of the inspected model class.
    architecture: str
    is_text_generation_model: bool
    # Always True: any model can be converted into a pooling model.
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    # Pipeline-parallel support.
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Build the capability flags for `model` using the interface checks."""
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=True,  # Can convert any model into a pooling model
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )


class _BaseRegisteredModel(ABC):
    """Interface of an entry in the model registry.

    Entries are passed as arguments to ``lru_cache``-decorated helpers, so
    concrete subclasses must be hashable (the ones in this module are frozen
    dataclasses).
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class, importing it if necessary."""
        raise NotImplementedError


@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability flags, computed in-process when the entry is created.
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap an imported model class, inspecting its interfaces eagerly."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the flags computed at registration time."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the wrapped model class (no import is needed)."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Dotted path of the module that defines the model class.
    module_name: str
    # Name of the model class inside `module_name`.
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """Inspect the model class in a subprocess and return its flags."""
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute.

        Raises whatever `importlib.import_module` / `getattr` raise
        (e.g. ImportError, AttributeError) if the target does not exist.
        """
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load the class registered for `model_arch`, or return None on failure.

    The current platform first verifies the architecture; exceptions raised
    by that check are not caught. Any exception raised while loading the
    class is logged and turned into ``None`` so callers can try the next
    candidate architecture. Results, including ``None``, are memoized per
    ``(model_arch, model)`` pair.
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)

    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None


@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Return the capability flags of `model`, or None if inspection fails.

    Failures are logged (with traceback) rather than raised, and the outcome
    is memoized per ``(model_arch, model)`` pair.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        info = None
    return info


@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to registered model classes.

    Entries are either eager (`_RegisteredModel`) or lazy
    (`_LazyRegisteredModel`). Lookup methods accept a single architecture
    name or a list of candidates, and try the candidates in order.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live, set-like view of all registered architectures."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that already exists replaces the
        previous entry and logs a warning.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If a string `model_cls` is not `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of `model_cls`, the argument that was rejected.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        # Warn only once the new entry is valid, so a rejected registration
        # does not announce an overwrite that never happens.
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]) -> NoReturn:
        """Raise a ValueError explaining why no candidate could be used.

        Distinguishes registered-but-failing architectures (the failure was
        already logged) from architectures that are not registered at all.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Load `model_arch`'s class; None if unregistered or loading fails."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect `model_arch`'s class; None if unregistered or it fails."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, list[str]],
    ) -> list[str]:
        """Return the registered subset of `architectures`, in input order.

        A single name may be passed as a string. If any requested
        architecture is not registered, ``TransformersForCausalLM`` is
        appended as a last-resort fallback (unless it is already present).
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # Filter out unsupported (unregistered) architectures.
        normalized_arch = [
            arch for arch in architectures if arch in self.models
        ]

        # Make sure the Transformers backend is put last as a fallback. If it
        # was requested explicitly, a second copy would only retry the same
        # (cached) load, so it is not appended again.
        fallback_arch = "TransformersForCausalLM"
        if (len(normalized_arch) != len(architectures)
                and fallback_arch not in normalized_arch):
            normalized_arch.append(fallback_arch)
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[_ModelInfo, str]:
        """Inspect the first candidate architecture that can be inspected.

        Returns:
            ``(model_info, arch)`` for the first successful candidate.

        Raises:
            ValueError: If no candidate can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[type[nn.Module], str]:
        """Load the first candidate architecture whose class can be loaded.

        Returns:
            ``(model_cls, arch)`` for the first successful candidate.

        Raises:
            ValueError: If no candidate can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture is a text generation model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture can be used as a pooling model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture supports cross-encoding."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture supports multimodal inputs."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture supports pipeline parallelism."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture has inner state."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture is attention-free."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture is a hybrid model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture has no-op layers."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture supports transcription."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture is not restricted to V0."""
        model_info, _ = self.inspect_model_cls(architectures)
        return not model_info.supports_v0_only

# Global registry of all built-in models. Every entry is lazy, so a model's
# module is only imported in this process when its class is actually loaded.
# Out-of-tree models can be added with `ModelRegistry.register_model`.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a fresh Python process and return its result.

    `fn` is serialized with `cloudpickle` (so lambdas and closures work) and
    sent on stdin to the process started with `_SUBPROCESS_COMMAND`. That
    process writes the pickled result to a temporary file, which is read
    back here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status. The
            message includes the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the exception to provide more information. Decode stderr
            # leniently so undecodable bytes cannot raise a UnicodeDecodeError
            # that would hide the child's actual failure.
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr_text}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Subprocess entry point used by `_run_in_subprocess`.

    Reads a pickled ``(fn, output_file)`` pair from stdin, calls ``fn()``,
    and writes the pickled result to ``output_file``.
    """
    # Setup plugins (presumably so plugin-registered models are available in
    # this child process too — confirm against `load_general_plugins`).
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: unpickling runs arbitrary code. stdin is expected to come from
    # `_run_in_subprocess` in the parent process; never feed it untrusted data.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)


# Executed as `python -m vllm.model_executor.models.registry` by
# `_run_in_subprocess` (see `_SUBPROCESS_COMMAND`).
if __name__ == "__main__":
    _run()