registry.py 26.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import importlib
8
import os
9
import pickle
10
11
import subprocess
import sys
12
import tempfile
13
from abc import ABC, abstractmethod
14
from collections.abc import Set
15
16
from dataclasses import dataclass, field
from functools import lru_cache
17
from typing import Callable, Optional, TypeVar, Union
18

19
import cloudpickle
20
21
22
23
import torch.nn as nn

from vllm.logger import init_logger

24
25
26
27
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
                         supports_multimodal, supports_pp,
                         supports_transcription, supports_v0_only)
28
from .interfaces_base import is_text_generation_model
29
30
31

logger = init_logger(__name__)

32
# yapf: disable
33
34
# Text-generation architectures.
# Each key is an architecture name; each value is a
# (module name, class name) pair. The module name is relative to
# `vllm.model_executor.models` (see where `ModelRegistry` is built).
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5_ForCausalLM": ("ernie45", "Ernie4_5_ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    # TODO(ywang96): Support multimodal gemma3n
    "Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"),    # noqa: E501
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # The transformers MPT class name uses lower case ("Mpt"), so both
    # spellings are registered
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

# Embedding architectures.
# Each key is an architecture name; each value is a
# (module name, class name) pair.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all.
        # This copies every text-generation entry whose class name is
        # "LlamaForCausalLM".
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

173
174
175
176
177
178
# Cross-encoder architectures.
# Each key is an architecture name; each value is a
# (module name, class name) pair.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    # [Auto-converted (see adapters.py)]
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"), # noqa: E501
    "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"),  # noqa: E501
}

186
# Multimodal architectures.
# Each key is an architecture name; each value is a
# (module name, class name) pair.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
    "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
236
237

# Speculative-decoding architectures.
# Each key is an architecture name; each value is a
# (module name, class name) pair.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EAGLEModel": ("eagle", "EAGLE"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
247

248
# Transformers backend entry. `_ModelRegistry._normalize_archs` appends this
# architecture name as a fallback when a requested architecture is not
# registered.
_TRANSFORMERS_MODELS = {
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
251
# yapf: enable
252

253
# Union of all built-in architecture tables. When the same key appears in
# more than one table, the entry from the table unpacked last wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}

262
263
264
265
266
267
268
269
# This variable is used as the args for subprocess.run() when a model is
# inspected in a separate process. It can be modified to alter the args if
# needed, e.g. when a par format is used to pack things together and
# sys.executable might not be the target we want to run.
_SUBPROCESS_COMMAND = [
    sys.executable, "-m", "vllm.model_executor.models.registry"
]

270

271
272
@dataclass(frozen=True)
class _ModelInfo:
273
    architecture: str
274
    is_text_generation_model: bool
275
    is_pooling_model: bool
276
    supports_cross_encoding: bool
277
278
    supports_multimodal: bool
    supports_pp: bool
279
280
    has_inner_state: bool
    is_attention_free: bool
281
    is_hybrid: bool
282
    has_noops: bool
283
    supports_transcription: bool
284
    supports_v0_only: bool
285
286

    @staticmethod
287
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
288
        return _ModelInfo(
289
            architecture=model.__name__,
290
            is_text_generation_model=is_text_generation_model(model),
291
            is_pooling_model=True,  # Can convert any model into a pooling model
292
            supports_cross_encoding=supports_cross_encoding(model),
293
294
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
295
296
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
297
            is_hybrid=is_hybrid(model),
298
299
            supports_transcription=supports_transcription(model),
            supports_v0_only=supports_v0_only(model),
300
            has_noops=has_noops(model),
301
        )
302
303


304
class _BaseRegisteredModel(ABC):
    """
    Abstract interface for an entry in the model registry.

    Subclasses must provide a way to inspect the model (returning a
    `_ModelInfo`) and a way to load the model class itself.
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing this model."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class for this entry."""
        raise NotImplementedError
313
314


315
316
317
318
319
320
321
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Registry entry for a model class that is already imported in the main
    process, so both inspection and loading are answered from stored fields.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Create an entry for `model_cls`, computing its `_ModelInfo` now."""
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=info, model_cls=model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` computed when the entry was created."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Registry entry for a model that has not been imported in the main
    process; it is identified only by its module and class names.
    """
    module_name: str
    class_name: str

    def inspect_model_cls(self) -> _ModelInfo:
        """Inspect the model class and return its `_ModelInfo`."""

        # Performed in another process to avoid initializing CUDA
        def _inspect():
            return _ModelInfo.from_model_cls(self.load_model_cls())

        return _run_in_subprocess(_inspect)

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        module = importlib.import_module(self.module_name)
        model_cls = getattr(module, self.class_name)
        return model_cls


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """
    Load the model class for `model_arch`, or return `None` on failure.

    The current platform verifies the architecture first; that check runs
    outside the `try`, so an exception from it propagates. Any exception
    from the load itself is logged and turned into `None`. Return values
    are memoized by `lru_cache`, keyed on both arguments.
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)

    try:
        loaded_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        loaded_cls = None
    return loaded_cls
369
370


371
372
373
374
375
376
377
378
379
380
381
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """
    Inspect `model` and return its `_ModelInfo`, or `None` on failure.

    Any exception from the inspection is logged together with `model_arch`
    and turned into `None`. Return values are memoized by `lru_cache`,
    keyed on both arguments.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        info = None
    return info
382
383


384
385
386
@dataclass
class _ModelRegistry:
    """
    Registry mapping model architecture names to registered models.

    Each entry is a `_BaseRegisteredModel`, which can be inspected (yielding
    a `_ModelInfo`) or loaded (yielding the model class). The capability
    predicates (`is_*`, `model_has_inner_state`) all go through
    `inspect_model_cls` and therefore raise `ValueError` when none of the
    given architectures can be inspected.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return the names of all registered architectures."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that already exists logs a warning and
        overwrites the previous entry.

        Raises `TypeError` if `model_arch` is not a string or `model_cls` is
        neither a string nor an `nn.Module` subclass, and `ValueError` if a
        string `model_cls` is not in the `<module>:<class>` format.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`)
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """
        Always raise `ValueError` for `architectures`.

        The message distinguishes registered architectures that failed to be
        inspected or loaded from architectures that are not registered at all.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Load the class for `model_arch`; `None` if unknown or failing."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect `model_arch`; `None` if unknown or failing."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, list[str]],
    ) -> list[str]:
        """
        Return the registered subset of `architectures`, in the given order.

        A single string is treated as a one-element list. If any requested
        architecture is not registered, `"TransformersForCausalLM"` is
        appended as a fallback.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # Keep only the architectures that are registered
        normalized_arch = [
            arch for arch in architectures if arch in self.models
        ]

        # Make sure the Transformers backend is put last as a fallback
        if len(normalized_arch) != len(architectures):
            normalized_arch.append("TransformersForCausalLM")
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[_ModelInfo, str]:
        """
        Return `(model_info, arch)` for the first architecture that can be
        inspected; raise `ValueError` if none can.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[type[nn.Module], str]:
        """
        Return `(model_cls, arch)` for the first architecture that can be
        loaded; raise `ValueError` if none can.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the resolved model's `is_text_generation_model` flag."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the resolved model's `is_pooling_model` flag."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the resolved model's `supports_cross_encoding` flag."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the resolved model's `supports_multimodal` flag."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the resolved model's `supports_pp` flag."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the resolved model's `has_inner_state` flag."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the resolved model's `is_attention_free` flag."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the resolved model's `is_hybrid` flag."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the resolved model's `has_noops` flag."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the resolved model's `supports_transcription` flag."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return `True` unless the resolved model has `supports_v0_only`."""
        model_info, _ = self.inspect_model_cls(architectures)
        return not model_info.supports_v0_only

580
581

# Global registry pre-populated with every architecture in `_VLLM_MODELS`.
# Each entry is a `_LazyRegisteredModel`, so no model module is imported
# while the registry is being built.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{module_relname}",
            class_name=class_name,
        )
        for arch_name, (module_relname, class_name) in _VLLM_MODELS.items()
    })

# Return type of the callable executed by `_run_in_subprocess`
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Call `fn` in a separate process and return its result.

    `fn` and an output path are serialized and sent to the child on stdin
    (the child is started with `_SUBPROCESS_COMMAND`). The child writes the
    pickled result to that path, which is read back here.

    Raises `RuntimeError`, carrying the child's stderr, if the child exits
    with a non-zero return code.
    """
    # NOTE: A temporary directory is used instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as workdir:
        result_path = os.path.join(workdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        payload = cloudpickle.dumps((fn, result_path))

        # `sys.executable __file__` cannot be used here because the script
        # contains relative imports
        completed = subprocess.run(_SUBPROCESS_COMMAND,
                                   input=payload,
                                   capture_output=True)

        # Surface a failed child as an error that includes its stderr
        try:
            completed.check_returncode()
        except Exception as exc:
            stderr_text = completed.stderr.decode()
            raise RuntimeError("Error raised in subprocess:\n"
                               f"{stderr_text}") from exc

        with open(result_path, "rb") as result_file:
            return pickle.load(result_file)


def _run() -> None:
    """
    Subprocess entry point: read a pickled `(fn, output_file)` pair from
    stdin, call `fn`, and write the pickled result to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE(review): the payload is unpickled as-is; presumably it always
    # comes from `_run_in_subprocess` in the parent process - confirm.
    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    outcome = fn()

    with open(output_file, "wb") as sink:
        serialized = pickle.dumps(outcome)
        sink.write(serialized)
631
632
633


if __name__ == "__main__":
    # Reached when this module is executed as a script, which is how
    # `_SUBPROCESS_COMMAND` launches it (`-m vllm.model_executor.models.registry`).
    _run()