registry.py 26.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import importlib
8
import os
9
import pickle
10
11
import subprocess
import sys
12
import tempfile
13
from abc import ABC, abstractmethod
14
from collections.abc import Set
15
16
from dataclasses import dataclass, field
from functools import lru_cache
17
from typing import Callable, Optional, TypeVar, Union
18
19
20
21
22

import torch.nn as nn

from vllm.logger import init_logger

23
24
25
26
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
                         supports_multimodal, supports_pp,
                         supports_transcription, supports_v0_only)
27
from .interfaces_base import is_text_generation_model
28
29
30

# Module logger; used to report registration overwrites and load/inspect
# failures of model architectures.
logger = init_logger(__name__)

31
# yapf: disable
# Text-generation architectures.
# Maps architecture name -> (module under `vllm.model_executor.models`,
# class name inside that module).
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5_ForCausalLM": ("ernie45", "Ernie4_5_ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    # TODO(ywang96): Support multimodal gemma3n
    "Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"),    # noqa: E501
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

# Embedding / pooling architectures.
# Maps architecture name -> (module under `vllm.model_executor.models`,
# class name inside that module).
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

174
175
176
177
178
179
# Cross-encoder (sequence classification / ranking) architectures.
# Maps architecture name -> (module under `vllm.model_executor.models`,
# class name inside that module).
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    # [Auto-converted (see adapters.py)]
    "GemmaForSequenceClassification": ("gemma", "GemmaForSequenceClassification"), # noqa: E501
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"), # noqa: E501
    "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501
}

189
# Multimodal architectures.
# Maps architecture name -> (module under `vllm.model_executor.models`,
# class name inside that module).
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
239
240

# Draft / auxiliary architectures used for speculative decoding.
# Maps architecture name -> (module under `vllm.model_executor.models`,
# class name inside that module).
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EAGLEModel": ("eagle", "EAGLE"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
250

251
# Generic Transformers-backend implementation; `_ModelRegistry` appends this
# architecture as the fallback for unregistered architecture names.
_TRANSFORMERS_MODELS = {
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
# yapf: enable
255

256
# Union of every per-task table. Dict unpacking means a key that appears in
# several tables takes its value from the last one listed here.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}

265
266
267
268
269
270
271
272
# Argument vector that `_run_in_subprocess` hands to `subprocess.run()`.
# It lives at module level so it can be overridden when needed: for example,
# when everything is packed into a single par archive, `sys.executable` might
# not be the target we want to run.
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]

273

274
275
@dataclass(frozen=True)
class _ModelInfo:
    """Immutable summary of the interfaces a model class implements.

    Built by `from_model_cls`, which evaluates each flag with the helper
    functions imported from `.interfaces` / `.interfaces_base`.
    """

    # `__name__` of the inspected model class.
    architecture: str
    is_text_generation_model: bool
    # Always True when built via `from_model_cls` (see comment there).
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Compute the capability flags of ``model``.

        Args:
            model: The model class (not an instance) to inspect.

        Returns:
            A `_ModelInfo` describing ``model``.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=True,  # Can convert any model into a pooling model
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
305
306


307
class _BaseRegisteredModel(ABC):
    """Abstract entry of the model registry.

    Subclasses decide how the model class is obtained (already imported, or
    imported lazily) and how its `_ModelInfo` is computed.
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
316
317


318
319
320
321
322
323
324
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Flags computed once, at construction time, in this process.
    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]) -> "_RegisteredModel":
        """Wrap ``model_cls``, inspecting its interfaces in-process.

        Args:
            model_cls: An already-imported model class.

        Returns:
            A `_RegisteredModel` holding ``model_cls`` and its `_ModelInfo`.
        """
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the flags computed in `from_model_cls`."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class (no import needed)."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Module path passed to `importlib.import_module`.
    module_name: str
    # Name of the model class attribute inside that module.
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """Import and inspect the model class in a subprocess."""
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import ``module_name`` and return its ``class_name`` attribute.

        Raises:
            ImportError: If the module cannot be imported.
            AttributeError: If the module has no attribute ``class_name``.
        """
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load the class behind ``model``, returning ``None`` on failure.

    ``current_platform.verify_model_arch`` is called first and outside the
    ``try``, so anything it raises propagates to the caller. Any exception
    from ``model.load_model_cls()`` is logged and turned into ``None``.
    Results, including ``None``, are memoized per ``(model_arch, model)``.
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
372
373


374
375
376
377
378
379
380
381
382
383
384
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Return ``model``'s capability flags, or ``None`` if inspection fails.

    Failures are logged with their traceback instead of being raised. The
    outcome (``None`` included) is memoized per ``(model_arch, model)``, so a
    lazy entry is not re-inspected in a new subprocess on every query.
    """
    info: Optional[_ModelInfo] = None
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
    return info
385
386


387
388
389
@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to model implementations.

    Each value is a `_BaseRegisteredModel`: either an already-imported class
    (`_RegisteredModel`) or a lazily imported `<module>:<class>` reference
    (`_LazyRegisteredModel`). Query methods accept one architecture name or a
    list of candidates and use the first candidate that can be
    inspected/loaded.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of all registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        An existing entry for `model_arch` is overwritten (with a warning).

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string not of the form
                `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            # Both halves must be non-empty, e.g. "my_pkg.models:MyModel".
            if len(split_str) != 2 or not all(split_str):
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        # Warn only once the new entry is known to be valid; a rejected
        # registration leaves the existing entry untouched.
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a `ValueError` explaining why none of ``architectures`` works.

        Always raises. The message distinguishes registered architectures
        that failed to inspect/load from entirely unregistered ones.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Load ``model_arch``'s class; ``None`` if unregistered or failing."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect ``model_arch``; ``None`` if unregistered or failing."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, list[str]],
    ) -> list[str]:
        """Return the registered subset of ``architectures``, order preserved.

        If any requested architecture is not registered,
        ``"TransformersForCausalLM"`` is appended (at most once) so that the
        Transformers backend is tried as a fallback.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # keep only the architectures that are registered
        normalized_arch = [
            arch for arch in architectures if arch in self.models
        ]

        # make sure Transformers backend is put at the last as a fallback;
        # skip it if the caller already listed it, to avoid a duplicate try.
        fallback_arch = "TransformersForCausalLM"
        if (len(normalized_arch) != len(architectures)
                and fallback_arch not in normalized_arch):
            normalized_arch.append(fallback_arch)
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[_ModelInfo, str]:
        """Return the info and name of the first inspectable architecture.

        Raises:
            ValueError: If no candidate architecture can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[type[nn.Module], str]:
        """Return the class and name of the first loadable architecture.

        Raises:
            ValueError: If no candidate architecture can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return not model_info.supports_v0_only

583
584

# Global registry pre-populated with every built-in architecture. All entries
# are lazy, so no model module is imported until it is inspected or loaded.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

# Return type of the callable executed by `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Execute ``fn`` in a fresh Python subprocess and return its result.

    ``fn`` is serialized with ``cloudpickle``, so lambdas and closures work.
    It is sent on stdin to the process started by ``_SUBPROCESS_COMMAND``,
    together with the path of a temporary output file. The child (``_run``)
    writes the pickled result to that file, and this function loads it back.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status; the
            child's stderr is included in the message.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the error to surface the child's stderr. Decode leniently
            # so non-UTF-8 output cannot hide the real failure behind a
            # UnicodeDecodeError.
            stderr = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Subprocess entry point used by `_run_in_subprocess`.

    Loads the general plugins, reads a pickled ``(fn, output_file)`` pair
    from stdin, calls ``fn`` and pickles its result into ``output_file``.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: unpickling can execute arbitrary code. The payload is produced by
    # the parent process in `_run_in_subprocess` and must never come from an
    # untrusted source.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)
635
636
637


if __name__ == "__main__":
    # Reached when launched via `_SUBPROCESS_COMMAND`
    # (`python -m vllm.model_executor.models.registry`).
    _run()