registry.py 26.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import importlib
8
import os
9
import pickle
10
11
import subprocess
import sys
12
import tempfile
13
from abc import ABC, abstractmethod
14
from collections.abc import Set
15
16
from dataclasses import dataclass, field
from functools import lru_cache
17
from typing import Callable, Optional, TypeVar, Union
18
19
20
21
22

import torch.nn as nn

from vllm.logger import init_logger

23
24
25
26
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
                         supports_multimodal, supports_pp,
                         supports_transcription, supports_v0_only)
27
from .interfaces_base import is_text_generation_model
28
29
30

# Module-level logger: used for registration-overwrite warnings and for
# logging exceptions raised while loading/inspecting model classes.
logger = init_logger(__name__)

31
# yapf: disable
32
33
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
34
35
36
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
37
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
38
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
39
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
40
41
42
43
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
Yu Chin Fabian Lim's avatar
Yu Chin Fabian Lim committed
44
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
45
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
46
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
47
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
48
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
49
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
50
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
51
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
52
53
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
54
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
55
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
56
57
    "Ernie4_5_ForCausalLM": ("ernie45", "Ernie4_5_ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
58
59
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
60
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
61
62
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
63
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
Robert Shaw's avatar
Robert Shaw committed
64
65
    #TODO(ywang96): Support multimodal gemma3n
    "Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"),    # noqa: E501
66
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
67
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
68
69
70
71
72
73
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
74
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
75
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
76
    "GritLM": ("gritlm", "GritLM"),
Michael Goin's avatar
Michael Goin committed
77
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
78
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"),
79
80
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
81
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
82
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
83
84
85
86
87
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
88
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
89
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
Dhia Eddine Rhaiem's avatar
Dhia Eddine Rhaiem committed
90
    "FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"),
91
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
92
93
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
94
95
96
97
98
99
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
100
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
101
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
Luis Vega's avatar
Luis Vega committed
102
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
103
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
104
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
105
106
107
108
109
110
111
112
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
Shinichi Hemmi's avatar
Shinichi Hemmi committed
113
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
114
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
115
116
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
117
118
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
119
120
121
122
123
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
124
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
125
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
126
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
127
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
128
129
130
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
131
132
133
}

# Architectures usable for embedding / pooling. Same `(module, class_name)`
# value format as `_TEXT_GENERATION_MODELS`.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

173
174
175
176
177
178
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
xsank's avatar
xsank committed
179
180
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
181
    # [Auto-converted (see adapters.py)]
182
    "GemmaForSequenceClassification": ("gemma", "GemmaForSequenceClassification"), # noqa: E501
183
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"), # noqa: E501
184
    "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501
185
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501
186
187
}

188
_MULTIMODAL_MODELS = {
189
    # [Decoder-only]
190
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
Jennifer Zhao's avatar
Jennifer Zhao committed
191
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
192
193
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
194
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
195
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
196
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
197
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
198
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
199
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
200
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
201
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
202
    "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
203
    "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"),  # noqa: E501
204
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
205
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
206
207
208
209
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
210
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
211
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
212
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
213
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
214
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
215
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
216
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
217
    "Ovis": ("ovis", "Ovis"),
218
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
219
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
220
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
221
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
222
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
Roger Wang's avatar
Roger Wang committed
223
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
224
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
225
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
226
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
227
    "UltravoxModel": ("ultravox", "UltravoxModel"),
228
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
汪志鹏's avatar
汪志鹏 committed
229
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
230
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
231
    # [Encoder-decoder]
232
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
233
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
234
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
235
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
236
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
237
}
238
239

_SPECULATIVE_DECODING_MODELS = {
240
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
241
    "EAGLEModel": ("eagle", "EAGLE"),
242
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
243
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
244
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
245
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
246
247
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
248
}
249

250
_TRANSFORMERS_MODELS = {
251
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
252
}
253
# yapf: enable
254

255
# Union of all built-in tables. On duplicate keys the later table's value
# wins (standard `{**a, **b}` semantics).
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}

264
265
266
267
268
269
270
271
# This variable is used as the args for subprocess.run(). We
# can modify  this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
_SUBPROCESS_COMMAND = [
    sys.executable, "-m", "vllm.model_executor.models.registry"
]

272

273
274
@dataclass(frozen=True)
class _ModelInfo:
275
    architecture: str
276
    is_text_generation_model: bool
277
    is_pooling_model: bool
278
    supports_cross_encoding: bool
279
280
    supports_multimodal: bool
    supports_pp: bool
281
282
    has_inner_state: bool
    is_attention_free: bool
283
    is_hybrid: bool
284
    has_noops: bool
285
    supports_transcription: bool
286
    supports_v0_only: bool
287
288

    @staticmethod
289
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
290
        return _ModelInfo(
291
            architecture=model.__name__,
292
            is_text_generation_model=is_text_generation_model(model),
293
            is_pooling_model=True,  # Can convert any model into a pooling model
294
            supports_cross_encoding=supports_cross_encoding(model),
295
296
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
297
298
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
299
            is_hybrid=is_hybrid(model),
300
301
            supports_transcription=supports_transcription(model),
            supports_v0_only=supports_v0_only(model),
302
            has_noops=has_noops(model),
303
        )
304
305


306
class _BaseRegisteredModel(ABC):
    """Abstract interface for an entry in the model registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability summary of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
315
316


317
318
319
320
321
322
323
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap an imported class, inspecting its capabilities up front."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        # Computed eagerly in `from_model_cls`; nothing to do here.
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    module_name: str
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """
    Load the class registered for `model_arch`.

    Returns None (after logging the traceback) if `load_model_cls` raises.
    Results, including None on failure, are memoized by `lru_cache`.
    `verify_model_arch` is called outside the `try`, so any exception it
    raises propagates to the caller.
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
371
372


373
374
375
376
377
378
379
380
381
382
383
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """
    Inspect the model registered for `model_arch`.

    Any exception from `inspect_model_cls` is logged and turned into a
    None result; results (including None) are memoized by `lru_cache`.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        info = None
    return info
384
385


386
387
388
@dataclass
class _ModelRegistry:
    """
    Registry of the model architectures known to vLLM.

    Each entry maps an architecture name to either an already-imported model
    class (`_RegisteredModel`) or one imported on demand
    (`_LazyRegisteredModel`).

    The `is_*` / `model_has_*` predicates all go through
    `inspect_model_cls`: they report one field of the `_ModelInfo` of the
    first candidate architecture that can be inspected, and raise
    `ValueError` if none can.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of the registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string not of the form
                `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument; `model_arch` is
            # already known to be a str at this point.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        # Only warn once the new entry is known to be valid, so a rejected
        # registration does not log a misleading "overwritten" message.
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """
        Raise a `ValueError` explaining why none of `architectures` worked.

        If any of them is registered, the failure is reported as an
        inspection failure (the underlying exception was logged by the
        `_try_*` helpers); otherwise they are reported as unsupported.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Load `model_arch`'s class; None if unregistered or loading failed."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect `model_arch`; None if unregistered or inspection failed."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, list[str]],
    ) -> list[str]:
        """
        Turn `architectures` into the ordered list of candidates to try.

        Unregistered names are dropped. If any were dropped, the
        Transformers backend is appended (once) as a last-resort fallback.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # keep only registered architectures, preserving their order
        normalized_arch = [
            arch for arch in architectures if arch in self.models
        ]

        # make sure Transformers backend is put at the last as a fallback;
        # skip it if the caller already listed it explicitly.
        if (len(normalized_arch) != len(architectures)
                and "TransformersForCausalLM" not in normalized_arch):
            normalized_arch.append("TransformersForCausalLM")
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[_ModelInfo, str]:
        """
        Return the `_ModelInfo` of the first candidate architecture that can
        be inspected, together with that architecture's name.

        Raises:
            ValueError: If no candidate can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[type[nn.Module], str]:
        """
        Return the class of the first candidate architecture that can be
        loaded, together with that architecture's name.

        Raises:
            ValueError: If no candidate can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_transcription

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        # V1-compatible unless the model declares itself V0-only.
        return not model_cls.supports_v0_only

582
583

# Global registry. Every built-in architecture is registered lazily, so no
# model module is imported until that model is loaded or inspected.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Run `fn` in a fresh interpreter (`_SUBPROCESS_COMMAND`) and return its
    result.

    `fn` and the path of an output file are serialized with `cloudpickle`
    and sent on the child's stdin; the child writes the pickled result to
    that file, which is then unpickled here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status; the
            message includes the subprocess's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently: a strict UTF-8 decode of arbitrary stderr
            # bytes could raise UnicodeDecodeError here and hide the real
            # failure. Wrap to provide more information.
            stderr = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """
    Child-process entry point used by `_run_in_subprocess`.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn()` and
    writes the pickled result to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: unpickling executes code; this is only meant to receive the
    # payload written by `_run_in_subprocess`, never untrusted input.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)
634
635
636


if __name__ == "__main__":
    # Reached via `python -m vllm.model_executor.models.registry`
    # (`_SUBPROCESS_COMMAND`) when `_run_in_subprocess` spawns a child.
    _run()