# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
import importlib
import os
import pickle
import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
from collections.abc import Set
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Callable, Optional, TypeVar, Union

import torch.nn as nn

from vllm.logger import init_logger

from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
                         supports_multimodal, supports_pp,
                         supports_transcription, supports_v0_only)
from .interfaces_base import is_text_generation_model

# Module-level logger; used below to report registration conflicts and
# failures while loading or inspecting a model class.
logger = init_logger(__name__)

# yapf: disable
# Maps an architecture name to a `(module, class)` pair, where `module` is
# relative to `vllm.model_executor.models` and `class` is the implementing
# class inside that module.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5_ForCausalLM": ("ernie45", "Ernie4_5_ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    # TODO(ywang96): Support multimodal gemma3n
    "Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"),  # noqa: E501
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

# Architecture name -> `(module, class)` for embedding / pooling models.
# Same layout as `_TEXT_GENERATION_MODELS`.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

# Architecture name -> `(module, class)` for cross-encoder models.
# Same layout as `_TEXT_GENERATION_MODELS`.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    # [Auto-converted (see adapters.py)]
    "GemmaForSequenceClassification": ("gemma", "GemmaForSequenceClassification"),  # noqa: E501
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"),  # noqa: E501
    "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"),  # noqa: E501
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501
}

# Architecture name -> `(module, class)` for multimodal models.
# Same layout as `_TEXT_GENERATION_MODELS`.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}

# Architecture name -> `(module, class)` for speculative-decoding models.
# Same layout as `_TEXT_GENERATION_MODELS`.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EAGLEModel": ("eagle", "EAGLE"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

# Entry for the Transformers backend; `_normalize_archs` appends this
# architecture as a fallback when a requested architecture is not registered.
_TRANSFORMERS_MODELS = {
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
# yapf: enable

# Union of every built-in table. When the same architecture name appears in
# more than one table, the entry from the table listed later wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}

# Argument vector handed to subprocess.run() by `_run_in_subprocess`.
# It lives in a module-level variable so that it can be altered if needed,
# e.g. when things are packed together in par format and sys.executable
# might not be the target we want to run.
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]


@dataclass(frozen=True)
class _ModelInfo:
    """Immutable summary of the capabilities reported for one model class.

    Apart from `architecture`, every field is a boolean flag. Instances are
    normally built with `from_model_cls`, which fills each flag from the
    interface probe of the same name.
    """
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Describe `model` by running every interface probe on it.

        Args:
            model: The model class to inspect.

        Returns:
            A `_ModelInfo` whose `architecture` is `model.__name__`.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=True,  # Can convert any model into a pooling model
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )


class _BaseRegisteredModel(ABC):
    """Interface shared by every entry stored in the model registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` describing this entry's model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class this entry refers to."""
        raise NotImplementedError


@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Wrap `model_cls`, computing its `_ModelInfo` immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` computed when this entry was created."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    module_name: str
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """Import and inspect the model class in a subprocess."""
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load the model class for `model_arch`, or return None on failure.

    The result is memoized per `(model_arch, model)` pair, so both arguments
    must be hashable and a failed load is not retried while it stays cached.

    Args:
        model_arch: Architecture name; checked by the current platform first
            and used in the error log.
        model: Registry entry whose `load_model_cls` is called.

    Returns:
        The loaded model class, or None if loading raised (the exception is
        logged).
    """
    from vllm.platforms import current_platform
    # Deliberately outside the `try`: whatever the platform check raises
    # propagates to the caller instead of being logged and swallowed.
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None


@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Inspect the model class for `model_arch`, or return None on failure.

    The result is memoized per `(model_arch, model)` pair. When inspection
    raises, the exception is logged and None is returned.
    """
    model_info: Optional[_ModelInfo] = None
    try:
        model_info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
    return model_info


@dataclass
class _ModelRegistry:
    """Registry that maps architecture names to model entries.

    Entries can be added with `register_model`. The query methods accept one
    architecture name or a list of names and answer for the first name that
    can be inspected or loaded successfully.
    """
    # Keyed by model_arch
    models: dict[str, "_BaseRegisteredModel"] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a view of all registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If `model_cls` is a string that is not in the
                `<module>:<class>` format.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`);
            # `model_arch` has already been validated as a string above.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Always raise `ValueError` for architectures that could not be used.

        The message distinguishes registered architectures that failed to be
        inspected from architectures that are not registered at all.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Load the class for `model_arch`; None if unknown or loading fails."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self,
                               model_arch: str) -> Optional["_ModelInfo"]:
        """Inspect `model_arch`; None if unknown or inspection fails."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, list[str]],
    ) -> list[str]:
        """Return the registered subset of `architectures`, in order.

        A single string is treated as a one-element list. If any requested
        architecture is not registered, `"TransformersForCausalLM"` is
        appended as a fallback.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # filter out support architectures
        normalized_arch = list(
            filter(lambda model: model in self.models, architectures))

        # make sure Transformers backend is put at the last as a fallback
        if len(normalized_arch) != len(architectures):
            normalized_arch.append("TransformersForCausalLM")
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple["_ModelInfo", str]:
        """Return `(model_info, arch)` for the first inspectable architecture.

        Raises:
            ValueError: If none of the architectures can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[type[nn.Module], str]:
        """Return `(model_cls, arch)` for the first loadable architecture.

        Raises:
            ValueError: If none of the architectures can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `is_text_generation_model` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `is_pooling_model` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `supports_cross_encoding` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `supports_multimodal` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `supports_pp` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `has_inner_state` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `is_attention_free` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `is_hybrid` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `has_noops` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return the `supports_transcription` flag of the resolved model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Return True unless the resolved model sets `supports_v0_only`."""
        model_info, _ = self.inspect_model_cls(architectures)
        return not model_info.supports_v0_only


# Global registry, pre-populated with one lazy entry per architecture in
# `_VLLM_MODELS`. A model module is imported only when its entry is loaded
# or inspected.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run `fn` in a fresh subprocess and return its result.

    `fn` and an output path are serialized with `cloudpickle` and piped to
    `_SUBPROCESS_COMMAND` on stdin. The subprocess writes the pickled result
    to the output path, which is read back here.

    Args:
        fn: Zero-argument callable to execute in the subprocess.

    Returns:
        The value unpickled from the subprocess's output file.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero return code;
            the message carries the subprocess's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except Exception as e:
            # Decode leniently: stderr may contain bytes that are not valid
            # UTF-8, and a decode error here would hide the real failure.
            stderr_text = returned.stderr.decode(errors="replace")
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr_text}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Subprocess entry point used by `_run_in_subprocess`.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn`, and
    writes the pickled result to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: unpickling executes arbitrary code. The payload here is produced
    # by `_run_in_subprocess` in the parent process; never feed this entry
    # point data from an untrusted source.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)


# Executed when this module is launched via `_SUBPROCESS_COMMAND`
# (`python -m vllm.model_executor.models.registry`).
if __name__ == "__main__":
    _run()