registry.py 28.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7
import importlib
8
import os
9
import pickle
10
11
import subprocess
import sys
12
import tempfile
13
from abc import ABC, abstractmethod
14
from collections.abc import Set
15
16
from dataclasses import dataclass, field
from functools import lru_cache
17
from typing import Callable, Optional, TypeVar, Union
18

19
import cloudpickle
20
21
22
23
import torch.nn as nn

from vllm.logger import init_logger

24
25
26
27
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
                         supports_multimodal, supports_pp,
                         supports_transcription, supports_v0_only)
28
from .interfaces_base import is_text_generation_model
29
30
31

logger = init_logger(__name__)

32
# yapf: disable
33
34
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
35
36
37
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
38
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
39
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
40
41
42
43
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
Yu Chin Fabian Lim's avatar
Yu Chin Fabian Lim committed
44
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
45
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
46
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
47
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
48
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
49
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
50
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
51
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
52
53
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
54
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
55
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
56
57
    "Ernie4_5_ForCausalLM": ("ernie45", "Ernie4_5_ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
58
59
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
60
    "FM9GForCausalLM": ("fm9g", "FM9GForCausalLM"),
61
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
62
63
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
64
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
Robert Shaw's avatar
Robert Shaw committed
65
66
    #TODO(ywang96): Support multimodal gemma3n
    "Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"),    # noqa: E501
67
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
Yuxuan Zhang's avatar
Yuxuan Zhang committed
68
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
zhuwenwen's avatar
zhuwenwen committed
69
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
70
71
72
73
74
75
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
76
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),   # noqa: E501
77
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
78
    "GritLM": ("gritlm", "GritLM"),
Michael Goin's avatar
Michael Goin committed
79
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
80
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"),
81
82
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
83
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
84
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
85
86
87
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
88
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
89
90
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
91
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
92
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
Dhia Eddine Rhaiem's avatar
Dhia Eddine Rhaiem committed
93
    "FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"),
94
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
95
96
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
97
98
99
100
101
102
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
103
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
104
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
Luis Vega's avatar
Luis Vega committed
105
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
106
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
107
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
108
109
110
111
112
113
114
115
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
Shinichi Hemmi's avatar
Shinichi Hemmi committed
116
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
117
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
118
119
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
120
121
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
122
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
zhuwenwen's avatar
zhuwenwen committed
123
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
124
125
126
127
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
128
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
129
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
130
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
131
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
132
133
    "Ernie4_5_ForCausalLM": ("ernie45", "Ernie4_5_ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
134
135
136
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
zhuwenwen's avatar
zhuwenwen committed
137
138
139
140
141
    # step model
    "Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
    "Step2ForCausalLM": ("step1", "Step1ForCausalLM"),
    "Step1MoEForCausalLM": ("step1", "Step1ForCausalLM"),
    "Step2MiniForCausalLM": ("step2_mini", "Step2MiniForCausalLM"),
142
143
144
}

_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
    # step model
    "Step1ForSequenceClassification": ("step1",
                                       "Step1ForSequenceClassification"),
    "Step2ForClassification": ("step1", "Step1ForSequenceClassification"),
    "Step2ForSequenceClassification": ("step2",
                                       "Step2ForSequenceClassification"),
    "Step2MiniForClassification": ("step2_mini",
                                   "Step2MiniForSequenceClassification"),
    "MMGPTQwen2RewardModel": ("mm_step1o", "MMGPTStep1oRewardModel"),
}

197
198
199
200
201
202
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
xsank's avatar
xsank committed
203
204
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
205
206
207
    # [Auto-converted (see adapters.py)]
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"), # noqa: E501
    "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"),  # noqa: E501
208
209
}

210
_MULTIMODAL_MODELS = {
211
    # [Decoder-only]
212
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
Jennifer Zhao's avatar
Jennifer Zhao committed
213
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
214
215
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
216
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
217
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
218
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
219
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
220
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),  # noqa: E501
221
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
222
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
223
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
224
    "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
225
    "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"),  # noqa: E501
226
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
227
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
228
229
230
231
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
232
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
233
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
234
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
235
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
236
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
237
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
238
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
239
    "Ovis": ("ovis", "Ovis"),
240
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
241
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
242
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
243
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
244
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
Roger Wang's avatar
Roger Wang committed
245
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
246
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
247
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
248
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
249
    "UltravoxModel": ("ultravox", "UltravoxModel"),
zhuwenwen's avatar
zhuwenwen committed
250
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
251
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
汪志鹏's avatar
汪志鹏 committed
252
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
253
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
254
    # [Encoder-decoder]
255
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
256
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
257
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
258
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
259
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
260
}
261
262

_SPECULATIVE_DECODING_MODELS = {
263
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
264
    "EAGLEModel": ("eagle", "EAGLE"),
265
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
266
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
267
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
268
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
zhuwenwen's avatar
zhuwenwen committed
269
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
270
271
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
zhuwenwen's avatar
zhuwenwen committed
272
273
274
275
276
277
278
279
280
    # step model
    "MMGPTStep1ForCausalLMV2": ("mm_step1p5c_1u", "MMGPTStep1ForCausalLMV2"),
    "MMGPTStep1ForCausalLMV3": ("mm_step1p5c_1u", "MMGPTStep1ForCausalLMV3"),
    "MMGPTStep1ForCausalLMV4": ("mm_step1o", "MMGPTStep1oForCausalLM"),
    "MMGPTQwen2ForCausalLM": ("mm_step1p5c_1u", "MMGPTStep1ForCausalLMV3"),
    "MMGPTQwen2ForCausalLMV2": ("mm_step1o", "MMGPTStep1oForCausalLM"),
    "MMGPTStep3vForCausalLM": ("mm_step1o", "MMGPTStep1oForCausalLM"),
    "Step1AudioForCausalLM": ("mm_step_audio", "MMGPTStep1fForCausalLM"),
    "StepAudioForCausalLMV2": ("mm_step_audio", "MMGPTStep1fForCausalLM"),
281
}
282

283
_TRANSFORMERS_MODELS = {
284
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
285
}
286
# yapf: enable
287

288
# All built-in architectures. If an architecture name appears in more than
# one table, the entry from the table merged later below takes precedence.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}

297
298
299
300
301
302
303
304
# This variable is used as the args for subprocess.run(). We
# can modify  this variable to alter the args if needed. e.g.
# when we use par format to pack things together, sys.executable
# might not be the target we want to run.
_SUBPROCESS_COMMAND = [
    sys.executable, "-m", "vllm.model_executor.models.registry"
]

305

306
307
@dataclass(frozen=True)
class _ModelInfo:
308
    architecture: str
309
    is_text_generation_model: bool
310
    is_pooling_model: bool
311
    supports_cross_encoding: bool
312
313
    supports_multimodal: bool
    supports_pp: bool
314
315
    has_inner_state: bool
    is_attention_free: bool
316
    is_hybrid: bool
317
    has_noops: bool
318
    supports_transcription: bool
319
    supports_v0_only: bool
320
321

    @staticmethod
322
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
323
        return _ModelInfo(
324
            architecture=model.__name__,
325
            is_text_generation_model=is_text_generation_model(model),
326
            is_pooling_model=True,  # Can convert any model into a pooling model
327
            supports_cross_encoding=supports_cross_encoding(model),
328
329
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
330
331
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
332
            is_hybrid=is_hybrid(model),
333
334
            supports_transcription=supports_transcription(model),
            supports_v0_only=supports_v0_only(model),
335
            has_noops=has_noops(model),
336
        )
337
338


339
class _BaseRegisteredModel(ABC):
    """Abstract entry stored in the model registry.

    Concrete subclasses either hold an already-imported model class or a
    lazy ``module:class`` reference that is imported only when needed.
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` capability flags of the model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the model class itself, importing it if necessary."""
        raise NotImplementedError
348
349


350
351
352
353
354
355
356
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability flags computed once, at registration time
    interfaces: _ModelInfo
    # The imported model class
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Create an entry for ``model_cls``, inspecting it in this process."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags computed when the entry was created."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the already-imported model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Importable module path, e.g. "vllm.model_executor.models.llama"
    module_name: str
    # Name of the model class defined in `module_name`
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        """Import and inspect the model class in a separate subprocess."""
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> type[nn.Module]:
        """Import ``module_name`` in this process and return ``class_name``."""
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load the model class behind ``model``, or return ``None`` on failure.

    The current platform first verifies ``model_arch``. That call is outside
    the ``try`` block, so anything it raises propagates to the caller. Errors
    raised while loading the class itself are logged with a traceback and
    turned into ``None``. Results are memoized per ``(model_arch, model)``.
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
404
405


406
407
408
409
410
411
412
413
414
415
416
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Return the capability flags of ``model``, or ``None`` if that fails.

    Any exception raised during inspection is logged with its traceback and
    swallowed, so a caller can move on to the next candidate architecture.
    Results (including ``None``) are memoized per ``(model_arch, model)``.
    """
    info = None
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
    return info
417
418


419
420
421
@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to registered model entries.

    Keys are architecture names such as ``"LlamaForCausalLM"``. The query
    methods accept a single name or a list of names. They answer for the
    first registered name that can be inspected (or loaded) successfully.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of all registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        An existing registration for `model_arch` is overwritten (with a
        warning).

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor a `torch.nn.Module` subclass.
            ValueError: If a string `model_cls` is not `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument (`model_cls`); the
            # original message reported `model_arch`, which is always `str`
            # at this point.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Raise a ``ValueError`` explaining why no candidate was usable.

        It distinguishes two cases. In one, a registered architecture failed
        to be inspected or loaded, and the details are in the logs. In the
        other, none of the architectures are registered at all.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[type[nn.Module]]:
        """Return the class for ``model_arch``, or ``None`` if unavailable.

        Returns ``None`` if ``model_arch`` is not registered or its class
        fails to load.
        """
        if model_arch not in self.models:
            return None

        # Delegates to the module-level, memoized helper of the same name.
        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Return the flags for ``model_arch``, or ``None`` if unavailable.

        Returns ``None`` if ``model_arch`` is not registered or inspection
        fails.
        """
        if model_arch not in self.models:
            return None

        # Delegates to the module-level, memoized helper of the same name.
        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, list[str]],
    ) -> list[str]:
        """Return the registered subset of ``architectures``, in order.

        If any requested architecture is not registered, the Transformers
        backend (``"TransformersForCausalLM"``) is appended as a fallback,
        unless it is already among the candidates.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # Keep only the architectures that are registered
        normalized_arch = [
            arch for arch in architectures if arch in self.models
        ]

        # Make sure the Transformers backend is tried last as a fallback,
        # without adding a duplicate entry for it.
        if (len(normalized_arch) != len(architectures)
                and "TransformersForCausalLM" not in normalized_arch):
            normalized_arch.append("TransformersForCausalLM")
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[_ModelInfo, str]:
        """Return ``(model_info, arch)`` for the first inspectable candidate.

        Raises:
            ValueError: If no candidate architecture could be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, list[str]],
    ) -> tuple[type[nn.Module], str]:
        """Return ``(model_cls, arch)`` for the first loadable candidate.

        Raises:
            ValueError: If no candidate architecture could be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture is a text generation model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture can be used as a pooling model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture supports cross encoding."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture supports multimodal inputs."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture supports pipeline parallelism."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture has inner state."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture is attention-free."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture is a hybrid model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture has no-op layers (``has_noops``)."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture supports transcription."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        """Whether the resolved architecture is not restricted to V0 only."""
        model_info, _ = self.inspect_model_cls(architectures)
        return not model_info.supports_v0_only

615
616

# The global registry. Every built-in architecture is registered lazily, so a
# model module is only imported once its class is actually needed.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run ``fn`` in a fresh Python subprocess and return its result.

    The steps are:

    1. ``fn`` and the path of a temporary output file are serialized with
       ``cloudpickle`` and written to the child's stdin.
    2. The child is started with ``_SUBPROCESS_COMMAND``. By default this
       runs this module's ``__main__`` entry point.
    3. The child calls ``fn`` and pickles the result into that file.
    4. The result is read back from the file here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status; the
            message includes the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the error to include the child's stderr. Decode leniently
            # so undecodable bytes cannot replace the real failure with a
            # UnicodeDecodeError.
            stderr = returned.stderr.decode(errors="replace")
            raise RuntimeError(
                f"Error raised in subprocess:\n{stderr}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Subprocess entry point paired with `_run_in_subprocess`.

    Loads the general plugins and reads a pickled ``(fn, output_file)`` pair
    from stdin. It then calls ``fn()`` and pickles the result into
    ``output_file``.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: stdin is trusted here: it is written by `_run_in_subprocess` in
    # the parent process. Never feed external data to this entry point,
    # because unpickling can execute arbitrary code.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)
666
667
668


if __name__ == "__main__":
    # Executed as `python -m ...` (see `_SUBPROCESS_COMMAND`) by
    # `_run_in_subprocess` to inspect a model class out of process.
    _run()