"vllm/benchmarks/latency.py" did not exist on "8aca27fa11bfe0539b72002761add7d990af325e"
registry.py 50.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
7

8
import importlib
9
import json
10
import os
11
import pickle
12
13
import subprocess
import sys
14
import tempfile
15
from abc import ABC, abstractmethod
16
from collections.abc import Callable, Set
17
from dataclasses import asdict, dataclass, field
18
from functools import lru_cache
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Any, TypeVar
21
22

import torch.nn as nn
23
import transformers
24

25
from vllm import envs
26
27
28
29
30
from vllm.config import (
    ModelConfig,
    iter_architecture_defaults,
    try_match_architecture_defaults,
)
31
from vllm.logger import init_logger
32
from vllm.logging_utils import logtime
33
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
34
from vllm.utils.hashing import safe_hash
35

36
37
if TYPE_CHECKING:
    # Only imported for static type checkers; skipped at runtime.
    from vllm.config.model import AttnTypeStr
    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
else:
    # Runtime placeholders so annotations that use these names still resolve
    # without importing the config submodules.
    AttnTypeStr = Any
    SequencePoolingType = Any
    TokenPoolingType = Any
43
44


45
46
47
48
49
from .interfaces import (
    has_inner_state,
    has_noops,
    is_attention_free,
    is_hybrid,
Patrick von Platen's avatar
Patrick von Platen committed
50
    requires_raw_input_tokens,
51
    supports_cross_encoding,
52
    supports_late_interaction,
53
    supports_mamba_prefix_caching,
54
55
56
57
58
59
60
    supports_multimodal,
    supports_multimodal_encoder_tp_data,
    supports_multimodal_raw_input_only,
    supports_pp,
    supports_transcription,
)
from .interfaces_base import (
61
    get_attn_type,
62
63
    get_default_seq_pooling_type,
    get_default_tok_pooling_type,
64
65
66
    is_pooling_model,
    is_text_generation_model,
)
67
68
69

# Module-level logger, created through vLLM's `init_logger` with this
# module's name.
logger = init_logger(__name__)

70
71
# Text-generation architectures.
# Each entry maps an architecture name to a `(module name, class name)` pair.
# NOTE(review): the module name is presumably resolved relative to this
# models package -- confirm against the loader code.
# Several architecture names may deliberately map to the same implementation.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
    "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
    "ExaoneMoEForCausalLM": ("exaone_moe", "ExaoneMoeForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "Glm4MoeLiteForCausalLM": ("glm4_moe_lite", "Glm4MoeLiteForCausalLM"),
    "GlmMoeDsaForCausalLM": ("deepseek_v2", "GlmMoeDsaForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),  # noqa: E501
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "GrokForCausalLM"),
    "Grok1ForCausalLM": ("grok1", "GrokForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestCoderForCausalLM": ("llama", "LlamaForCausalLM"),
    "IQuestLoopCoderForCausalLM": ("iquest_loopcoder", "IQuestLoopCoderForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
    "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
    "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # The transformers MPT class name uses lower case ("Mpt")
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "NemotronHPuzzleForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
    "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
    "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
    "Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
    "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
    "Step3p5ForCausalLM": ("step3p5", "Step3p5ForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
}

# Embedding / pooling architectures.
# Each entry maps an architecture name to a `(module name, class name)` pair,
# in the same shape as `_TEXT_GENERATION_MODELS`.
# Every text-generation entry that resolves to "LlamaForCausalLM" is also
# merged in here by the `**{...}` comprehension below.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "HF_ColBERT": ("colbert", "ColBERTModel"),
    "ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
    "ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "VoyageQwen3BidirectionalEmbedModel": (
        "voyage",
        "VoyageQwen3BidirectionalEmbedModel",
    ),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "ColQwen3": ("colqwen3", "ColQwen3Model"),
    "OpsColQwen3Model": ("colqwen3", "ColQwen3Model"),
    "Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"),
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Technically, Terratorch models work on images for both input and
    # output. They are listed here because they piggy-back on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}

269
270
# Cross-encoder (sequence / token classification and ranking) architectures.
# Each entry maps an architecture name to a `(module name, class name)` pair.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": (
        "bert_with_rope",
        "GteNewForSequenceClassification",
    ),
    "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
    "LlamaBidirectionalForSequenceClassification": (
        "llama",
        "LlamaBidirectionalForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
    ),
    "ModernBertForTokenClassification": (
        "modernbert",
        "ModernBertForTokenClassification",
    ),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
}

296
# Multimodal architectures.
# Each entry maps an architecture name to a `(module name, class name)` pair.
# Grouped into decoder-only models followed by encoder-decoder models.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AudioFlamingo3ForConditionalGeneration": (
        "audioflamingo3",
        "AudioFlamingo3ForConditionalGeneration",
    ),
    "MusicFlamingoForConditionalGeneration": (
        "musicflamingo",
        "MusicFlamingoForConditionalGeneration",
    ),
    "AyaVisionForConditionalGeneration": (
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (
        "chameleon",
        "ChameleonForConditionalGeneration",
    ),
    "Cohere2VisionForConditionalGeneration": (
        "cohere2_vision",
        "Cohere2VisionForConditionalGeneration",
    ),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
    "DeepseekOCR2ForCausalLM": ("deepseek_ocr2", "DeepseekOCR2ForCausalLM"),
    "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
    "Eagle2_5_VLForConditionalGeneration": (
        "eagle2_5_vl",
        "Eagle2_5_VLForConditionalGeneration",
    ),
    "Ernie4_5_VLMoeForConditionalGeneration": (
        "ernie45_vl",
        "Ernie4_5_VLMoeForConditionalGeneration",
    ),
    "FunASRForConditionalGeneration": ("funasr", "FunASRForConditionalGeneration"),  # noqa: E501
    "FunAudioChatForConditionalGeneration": (
        "funaudiochat",
        "FunAudioChatForConditionalGeneration",
    ),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "Gemma3nForConditionalGeneration": (
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
    "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),
    "GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"),  # noqa: E501
    "GraniteSpeechForConditionalGeneration": (
        "granite_speech",
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
    "StepVLForConditionalGeneration": ("step_vl", "StepVLForConditionalGeneration"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "OpenCUAForConditionalGeneration": (
        "opencua",
        "OpenCUAForConditionalGeneration",
    ),
    "InternS1ForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternVLForConditionalGeneration": (
        "interns1",
        "InternS1ForConditionalGeneration",
    ),
    "InternS1ProForConditionalGeneration": (
        "interns1_pro",
        "InternS1ProForConditionalGeneration",
    ),
    "Idefics3ForConditionalGeneration": (
        "idefics3",
        "Idefics3ForConditionalGeneration",
    ),
    "IsaacForConditionalGeneration": ("isaac", "IsaacForConditionalGeneration"),
    "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"),  # noqa: E501
    "KananaVForConditionalGeneration": ("kanana_v", "KananaVForConditionalGeneration"),
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KeyeVL1_5ForConditionalGeneration": (
        "keye_vl1_5",
        "KeyeVL1_5ForConditionalGeneration",
    ),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"),  # noqa: E501
    "LightOnOCRForConditionalGeneration": (
        "lightonocr",
        "LightOnOCRForConditionalGeneration",
    ),
    "Lfm2VlForConditionalGeneration": ("lfm2_vl", "Lfm2VLForConditionalGeneration"),
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "LlavaNextVideoForConditionalGeneration": (
        "llava_next_video",
        "LlavaNextVideoForConditionalGeneration",
    ),
    "LlavaOnevisionForConditionalGeneration": (
        "llava_onevision",
        "LlavaOnevisionForConditionalGeneration",
    ),
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
    "MiniMaxVL01ForConditionalGeneration": (
        "minimax_vl_01",
        "MiniMaxVL01ForConditionalGeneration",
    ),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": (
        "mistral3",
        "Mistral3ForConditionalGeneration",
    ),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "Molmo2ForConditionalGeneration": ("molmo2", "Molmo2ForConditionalGeneration"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "OpenPanguVLForConditionalGeneration": (
        "openpangu_vl",
        "OpenPanguVLForConditionalGeneration",
    ),
    "Ovis": ("ovis", "Ovis"),
    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
    "Ovis2_6ForCausalLM": ("ovis2_5", "Ovis2_5"),
    "Ovis2_6_MoeForCausalLM": ("ovis2_5", "Ovis2_5"),
    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),
    "PaliGemmaForConditionalGeneration": (
        "paligemma",
        "PaliGemmaForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": (
        "qwen2_5_vl",
        "Qwen2_5_VLForConditionalGeneration",
    ),
    "Qwen2AudioForConditionalGeneration": (
        "qwen2_audio",
        "Qwen2AudioForConditionalGeneration",
    ),
    "Qwen2_5OmniModel": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen2_5OmniForConditionalGeneration": (
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3ASRForConditionalGeneration": (
        "qwen3_asr",
        "Qwen3ASRForConditionalGeneration",
    ),
    "Qwen3ASRRealtimeGeneration": (
        "qwen3_asr_realtime",
        "Qwen3ASRRealtimeGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
        "Qwen3VLMoeForConditionalGeneration",
    ),
    "Qwen3_5ForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5ForConditionalGeneration",
    ),
    "Qwen3_5MoeForConditionalGeneration": (
        "qwen3_5",
        "Qwen3_5MoeForConditionalGeneration",
    ),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"),  # noqa: E501
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    "Tarsier2ForConditionalGeneration": (
        "qwen2_vl",
        "Tarsier2ForConditionalGeneration",
    ),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
    "VoxtralRealtimeGeneration": ("voxtral_realtime", "VoxtralRealtimeGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "NemotronParseForConditionalGeneration": (
        "nemotron_parse",
        "NemotronParseForConditionalGeneration",
    ),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
507
508

# Draft / speculative-decoding architectures (Eagle, MTP, Medusa, ...).
# Each entry maps an architecture name to a `(module name, class name)` pair.
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "EagleMistralLarge3ForCausalLM": (
        "mistral_large_3_eagle",
        "EagleMistralLarge3ForCausalLM",
    ),
    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
    "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
    "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
    "Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
    "GlmOcrMTPModel": ("glm_ocr_mtp", "GlmOcrMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
    "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
    "Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
    "Qwen3_5MTP": ("qwen3_5_mtp", "Qwen3_5MTP"),
    "Qwen3_5MoeMTP": ("qwen3_5_mtp", "Qwen3_5MoeMTP"),
    # Temporarily disabled.
    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
539

540
# Architectures that are mapped onto the generic classes in the
# "transformers" module instead of a dedicated implementation.
# Each entry maps an architecture name to a `(module name, class name)` pair.
_TRANSFORMERS_SUPPORTED_MODELS = {
    # Text generation models
    "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
    # Multimodal models
    "Emu3ForConditionalGeneration": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
}

# The generic classes of the "transformers" module, registered under their
# own names. Every key equals the class name in its `(module name, class name)`
# value.
_TRANSFORMERS_BACKEND_MODELS = {
    # Text generation models
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
    # Multimodal models
    "TransformersMultiModalForCausalLM": (
        "transformers",
        "TransformersMultiModalForCausalLM",
    ),
    "TransformersMultiModalMoEForCausalLM": (
        "transformers",
        "TransformersMultiModalMoEForCausalLM",
    ),
    # Embedding models
    "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
    "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
    "TransformersMultiModalEmbeddingModel": (
        "transformers",
        "TransformersMultiModalEmbeddingModel",
    ),
    # Sequence classification models
    "TransformersForSequenceClassification": (
        "transformers",
        "TransformersForSequenceClassification",
    ),
    "TransformersMoEForSequenceClassification": (
        "transformers",
        "TransformersMoEForSequenceClassification",
    ),
    "TransformersMultiModalForSequenceClassification": (
        "transformers",
        "TransformersMultiModalForSequenceClassification",
    ),
}
584

585
# Union of all per-category registries above. The dicts are unpacked in
# order, so when an architecture name appears in more than one category the
# entry from the later dict is the one that is kept.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}

595
596
597
598
# This variable is used as the args for subprocess.run(). It can be
# modified to alter the args if needed, e.g. when a par format is used to
# pack things together, sys.executable might not be the target we want
# to run.
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
600

601
# Architectures that are no longer registered, each mapped to a version string.
# NOTE(review): the version is presumably the last vLLM release that still
# supported the architecture -- confirm against where this mapping is consumed.
_PREVIOUSLY_SUPPORTED_MODELS = {
    "MotifForCausalLM": "0.10.2",
    "Phi3SmallForCausalLM": "0.9.2",
    "Phi4FlashForCausalLM": "0.10.2",
    "Phi4MultimodalForCausalLM": "0.12.0",
    # encoder-decoder models except whisper
    # have been removed for V0 deprecation.
    "BartModel": "0.10.2",
    "BartForConditionalGeneration": "0.10.2",
    "DonutForConditionalGeneration": "0.10.2",
    "Florence2ForConditionalGeneration": "0.10.2",
    "MBartForConditionalGeneration": "0.10.2",
    "MllamaForConditionalGeneration": "0.10.2",
}
615

616

617
618
@dataclass(frozen=True)
class _ModelInfo:
619
    architecture: str
620
    is_text_generation_model: bool
621
    is_pooling_model: bool
622
    attn_type: AttnTypeStr
623
624
    default_seq_pooling_type: SequencePoolingType
    default_tok_pooling_type: TokenPoolingType
625
    supports_cross_encoding: bool
626
    supports_late_interaction: bool
627
    supports_multimodal: bool
628
    supports_multimodal_raw_input_only: bool
Patrick von Platen's avatar
Patrick von Platen committed
629
    requires_raw_input_tokens: bool
630
    supports_multimodal_encoder_tp_data: bool
631
    supports_pp: bool
632
633
    has_inner_state: bool
    is_attention_free: bool
634
    is_hybrid: bool
635
    has_noops: bool
636
    supports_mamba_prefix_caching: bool
637
    supports_transcription: bool
638
    supports_transcription_only: bool
639
640

    @staticmethod
641
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
642
        return _ModelInfo(
643
            architecture=model.__name__,
644
            is_text_generation_model=is_text_generation_model(model),
645
            is_pooling_model=is_pooling_model(model),
646
647
            default_seq_pooling_type=get_default_seq_pooling_type(model),
            default_tok_pooling_type=get_default_tok_pooling_type(model),
648
            attn_type=get_attn_type(model),
649
            supports_cross_encoding=supports_cross_encoding(model),
650
            supports_late_interaction=supports_late_interaction(model),
651
            supports_multimodal=supports_multimodal(model),
652
653
654
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
Patrick von Platen's avatar
Patrick von Platen committed
655
            requires_raw_input_tokens=requires_raw_input_tokens(model),
656
657
658
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
659
            supports_pp=supports_pp(model),
660
661
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
662
            is_hybrid=is_hybrid(model),
663
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
664
            supports_transcription=supports_transcription(model),
665
666
667
            supports_transcription_only=(
                supports_transcription(model) and model.supports_transcription_only
            ),
668
            has_noops=has_noops(model),
669
        )
670
671


672
673
674
675
class _BaseRegisteredModel(ABC):
    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        raise NotImplementedError
676

677
    @abstractmethod
678
    def load_model_cls(self) -> type[nn.Module]:
679
        raise NotImplementedError
680
681


682
683
684
685
686
687
688
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Registry entry for a model class that is already imported in the main
    process, stored together with its precomputed `_ModelInfo`.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Build an entry for `model_cls`, inspecting it immediately."""
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=info, model_cls=model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` computed when the entry was created."""
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        """Return the stored model class; nothing has to be imported."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Registry entry for a model that has not been imported in the main
    process. The model is identified by module and class name, and is only
    imported when `load_model_cls` is called.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        """Directory below `envs.VLLM_CACHE_ROOT` that holds model info files."""
        return Path(envs.VLLM_CACHE_ROOT, "modelinfos")

    def _get_cache_filename(self) -> str:
        """Cache file name for this entry: `<module>-<class>.json`, dots dashed."""
        qualified_name = f"{self.module_name}-{self.class_name}"
        return qualified_name.replace(".", "-") + ".json"

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """
        Return the cached `_ModelInfo` if its stored hash equals `module_hash`.

        Returns `None` when the cache file is missing, stale, or cannot be
        read or parsed; each case is only logged at debug level.
        """
        try:
            cache_file = self._get_cache_dir() / self._get_cache_filename()
            with open(cache_file, encoding="utf-8") as file:
                cached = json.load(file)

            if cached["hash"] == module_hash:
                # file not changed, use cached _ModelInfo properties
                return _ModelInfo(**cached["modelinfo"])

            logger.debug(
                "Cached model info file for class %s.%s is stale",
                self.module_name,
                self.class_name,
            )
        except FileNotFoundError:
            logger.debug(
                "Cached model info file for class %s.%s not found",
                self.module_name,
                self.class_name,
            )
        except Exception:
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
            )
        return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """Write `mi` and `module_hash` to the JSON cache file.

        Failures are logged and otherwise ignored.
        """
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            payload = {"hash": module_hash, "modelinfo": asdict(mi)}
            target_dir = self._get_cache_dir()
            target_dir.mkdir(parents=True, exist_ok=True)
            target_file = target_dir / self._get_cache_filename()
            with atomic_writer(target_file, encoding="utf-8") as f:
                json.dump(payload, f, indent=2)
        except Exception:
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """
        Return the `_ModelInfo` for this entry.

        If a file named after the last component of `module_name` exists next
        to this file, its content hash is used to look up the on-disk cache
        first. On a miss the model class is inspected in a subprocess, and the
        result is written back to the cache when a hash is available.
        """
        stem = self.module_name.rsplit(".", 1)[-1]
        model_path = Path(__file__).with_name(f"{stem}.py")
        module_hash = None

        if model_path.exists():
            source_bytes = model_path.read_bytes()
            module_hash = safe_hash(source_bytes, usedforsecurity=False).hexdigest()

            cached_info = self._load_modelinfo_from_cache(module_hash)
            if cached_info is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return cached_info

            logger.debug(
                "Cache model info for class %s.%s miss. Loading model instead.",
                self.module_name,
                self.class_name,
            )

        # Performed in another process to avoid initializing CUDA
        fresh_info = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(fresh_info, module_hash)

        return fresh_info

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        module = importlib.import_module(self.module_name)
        return getattr(module, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """
    Load the model class of `model`, or return `None` if loading raises.

    The current platform verifies `model_arch` first; an error raised by that
    check is not caught here. Results are memoized per `(model_arch, model)`.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)

    try:
        loaded_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        loaded_cls = None
    return loaded_cls
827
828


829
830
831
832
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
833
) -> _ModelInfo | None:
834
835
836
    try:
        return model.inspect_model_cls()
    except Exception:
837
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
838
        return None
839
840


841
842
843
@dataclass
class _ModelRegistry:
    # Keyed by model_arch
844
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
845

846
    def get_supported_archs(self) -> Set[str]:
847
        return self.models.keys()
848

849
850
851
    def register_model(
        self,
        model_arch: str,
852
        model_cls: type[nn.Module] | str,
853
    ) -> None:
854
855
856
        """
        Register an external model to be used in vLLM.

857
        `model_cls` can be either:
858

859
        - A [`torch.nn.Module`][] class directly referencing the model.
860
        - A string in the format `<module>:<class>` which can be used to
861
862
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
863
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
864
        """
865
866
867
868
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

869
        if model_arch in self.models:
870
871
            logger.warning(
                "Model architecture %s is already registered, and will be "
872
873
874
875
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )
876
877
878
879
880
881

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)
882

883
            model = _LazyRegisteredModel(*split_str)
884
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
885
            model = _RegisteredModel.from_model_cls(model_cls)
886
        else:
887
888
889
890
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_arch)}"
            )
891
            raise TypeError(msg)
892

893
        self.models[model_arch] = model
894

895
    def _raise_for_unsupported(self, architectures: list[str]):
896
        all_supported_archs = self.get_supported_archs()
897

898
899
900
        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
901
902
                "to be inspected. Please check the logs for more details."
            )
903

904
905
906
907
908
909
910
911
        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
912
913
                    "use this model architecture."
                )
914

915
916
        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
917
918
            f"Supported architectures: {all_supported_archs}"
        )
919

920
    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
921
922
        if model_arch not in self.models:
            return None
923

924
        return _try_load_model_cls(model_arch, self.models[model_arch])
925

926
    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
927
928
        if model_arch not in self.models:
            return None
929

930
931
932
933
934
935
        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
936
    ) -> str | None:
937
938
939
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

940
941
942
        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
959
                        trust_remote_code=model_config.trust_remote_code,
960
961
962
963
964
965
966
967
968
969
970
971
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
972
                        trust_remote_code=model_config.trust_remote_code,
973
974
975
976
977
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
978
                if model_config.model_impl != "transformers":
979
980
981
982
983
984
985
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
986
987
                    "'auto_map' (relevant if the model is custom)."
                )
988
989

        if not model_module.is_backend_compatible():
990
            if model_config.model_impl != "transformers":
991
                return None
992

993
994
            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
995
996
                "is not compatible with vLLM."
            )
997

998
        return model_config._get_transformers_backend_cls()
999

1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture
1025

1026
1027
    def inspect_model_cls(
        self,
1028
        architectures: str | list[str],
1029
        model_config: ModelConfig,
1030
    ) -> tuple[_ModelInfo, str]:
1031
1032
        if isinstance(architectures, str):
            architectures = [architectures]
1033
1034
        if not architectures:
            raise ValueError("No model architectures are specified")
1035
1036

        # Require transformers impl
1037
        if model_config.model_impl == "transformers":
1038
            arch = self._try_resolve_transformers(architectures[0], model_config)
1039
1040
1041
1042
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
1043
        elif model_config.model_impl == "terratorch":
1044
1045
            model_info = self._try_inspect_model_cls("Terratorch")
            return (model_info, "Terratorch")
1046

1047
        # Fallback to transformers impl (after resolving convert_type)
1048
1049
1050
1051
1052
1053
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
1054
1055
1056
1057
1058
1059
1060
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
1061
            model_info = self._try_inspect_model_cls(normalized_arch)
1062
            if model_info is not None:
1063
                return (model_info, arch)
1064

1065
        # Fallback to transformers impl (before resolving runner_type)
1066
1067
1068
1069
1070
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
1071
1072
1073
1074
1075
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

1076
        return self._raise_for_unsupported(architectures)
1077

1078
1079
    def resolve_model_cls(
        self,
1080
        architectures: str | list[str],
1081
        model_config: ModelConfig,
1082
    ) -> tuple[type[nn.Module], str]:
1083
1084
        if isinstance(architectures, str):
            architectures = [architectures]
1085
1086
        if not architectures:
            raise ValueError("No model architectures are specified")
1087
1088

        # Require transformers impl
1089
        if model_config.model_impl == "transformers":
1090
            arch = self._try_resolve_transformers(architectures[0], model_config)
1091
1092
1093
1094
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
1095
        elif model_config.model_impl == "terratorch":
1096
1097
1098
1099
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)
1100

1101
        # Fallback to transformers impl (after resolving convert_type)
1102
1103
1104
1105
1106
1107
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
1108
1109
1110
1111
1112
1113
1114
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
1115
            model_cls = self._try_load_model_cls(normalized_arch)
1116
1117
            if model_cls is not None:
                return (model_cls, arch)
1118

1119
        # Fallback to transformers impl (before resolving runner_type)
1120
1121
1122
1123
1124
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
1125
1126
1127
1128
1129
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

1130
        return self._raise_for_unsupported(architectures)
1131

1132
1133
    def is_text_generation_model(
        self,
1134
        architectures: str | list[str],
1135
        model_config: ModelConfig,
1136
    ) -> bool:
1137
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1138
        return model_cls.is_text_generation_model
1139

1140
    def is_pooling_model(
1141
        self,
1142
        architectures: str | list[str],
1143
        model_config: ModelConfig,
1144
    ) -> bool:
1145
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1146
        return model_cls.is_pooling_model
1147

1148
1149
    def is_cross_encoder_model(
        self,
1150
        architectures: str | list[str],
1151
        model_config: ModelConfig,
1152
    ) -> bool:
1153
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1154
        return model_cls.supports_cross_encoding
1155

1156
1157
    def is_multimodal_model(
        self,
1158
        architectures: str | list[str],
1159
        model_config: ModelConfig,
1160
    ) -> bool:
1161
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1162
        return model_cls.supports_multimodal
1163

1164
    def is_multimodal_raw_input_only_model(
1165
        self,
1166
        architectures: str | list[str],
1167
        model_config: ModelConfig,
1168
    ) -> bool:
1169
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1170
        return model_cls.supports_multimodal_raw_input_only
1171

1172
1173
    def is_pp_supported_model(
        self,
1174
        architectures: str | list[str],
1175
        model_config: ModelConfig,
1176
    ) -> bool:
1177
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1178
        return model_cls.supports_pp
1179

1180
1181
    def model_has_inner_state(
        self,
1182
        architectures: str | list[str],
1183
        model_config: ModelConfig,
1184
    ) -> bool:
1185
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1186
        return model_cls.has_inner_state
1187

1188
1189
    def is_attention_free_model(
        self,
1190
        architectures: str | list[str],
1191
        model_config: ModelConfig,
1192
    ) -> bool:
1193
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1194
        return model_cls.is_attention_free
1195

1196
1197
    def is_hybrid_model(
        self,
1198
        architectures: str | list[str],
1199
        model_config: ModelConfig,
1200
    ) -> bool:
1201
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1202
1203
        return model_cls.is_hybrid

1204
1205
    def is_noops_model(
        self,
1206
        architectures: str | list[str],
1207
        model_config: ModelConfig,
1208
    ) -> bool:
1209
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1210
1211
        return model_cls.has_noops

1212
1213
    def is_transcription_model(
        self,
1214
        architectures: str | list[str],
1215
        model_config: ModelConfig,
1216
    ) -> bool:
1217
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1218
1219
        return model_cls.supports_transcription

1220
1221
    def is_transcription_only_model(
        self,
1222
        architectures: str | list[str],
1223
        model_config: ModelConfig,
1224
    ) -> bool:
1225
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
1226
1227
        return model_cls.supports_transcription_only

1228

1229
1230
1231
1232
1233
1234
1235
1236
1237
# Module-level registry instance, pre-populated with one lazy entry per
# architecture in `_VLLM_MODELS`. Each `(mod_relname, cls_name)` pair is
# expanded to a module under `vllm.model_executor.models`; constructing the
# lazy entries does not import those modules.
ModelRegistry = _ModelRegistry(
    {
        model_arch: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{mod_relname}",
            class_name=cls_name,
        )
        for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
    }
)
1238
1239
1240
1241
1242

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Run `fn` in a separate process and return its result.

    `fn` and an output file path are serialized with `cloudpickle` and sent
    to the process started with `_SUBPROCESS_COMMAND` on stdin. The
    subprocess writes the pickled result of `fn()` to that file, which is
    read back here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status. The
            message carries the subprocess's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        import cloudpickle

        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
        )

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # wrap raised exception to provide more information.
            # stderr is decoded leniently so that undecodable bytes cannot
            # replace the real failure with a UnicodeDecodeError.
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n{stderr_text}") from e

        # The file was written by the subprocess we just ran.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """
    Entry point of the helper subprocess.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn` after the
    general plugins are loaded, and pickles the result into `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    # Compute the result before the output file is opened.
    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)
1284
1285
1286


# Executed when this module is run as a script, which is how
# `_SUBPROCESS_COMMAND` starts it (`python -m ...registry`).
if __name__ == "__main__":
    _run()