registry.py 23.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
6
import importlib
7
import os
8
import pickle
9
10
import subprocess
import sys
11
import tempfile
12
13
14
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import lru_cache
15
16
from typing import (AbstractSet, Callable, Dict, List, Optional, Tuple, Type,
                    TypeVar, Union)
17

18
import cloudpickle
19
20
21
import torch.nn as nn

from vllm.logger import init_logger
22
from vllm.utils import is_in_doc_build
23

24
25
26
27
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
                         supports_multimodal, supports_pp,
                         supports_transcription, supports_v0_only)
28
from .interfaces_base import is_text_generation_model
29
30
31

# Module-level logger; used below for registry warnings (e.g. overwriting an
# existing architecture) and for logging load/inspect failures.
logger = init_logger(__name__)

32
# yapf: disable
33
34
# Maps an architecture name to ``(module, class)``, where ``module`` is a
# submodule of ``vllm.model_executor.models`` and ``class`` is the name of the
# implementation inside it (consumed by ``ModelRegistry`` further down).
# Insertion order is significant: ``_EMBEDDING_MODELS`` copies the
# Llama-based entries from this dict in order.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

# Architectures usable as embedding / pooling models, in the same
# ``name -> (module, class)`` format as ``_TEXT_GENERATION_MODELS``.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # [Auto-converted (see adapters.py)]
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. It is listed here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

156
157
158
159
160
161
162
163
# Sequence-classification architectures registered as cross-encoder models
# (``name -> (module, class)``; both RoBERTa variants share one class).
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),  # noqa: E501
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),  # noqa: E501
    "XLMRobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),  # noqa: E501
}

164
# Multimodal architectures, in the same ``name -> (module, class)`` format as
# the other tables in this file.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),  # noqa: E501
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
203
204
205

# Draft / speculator architectures used for speculative decoding.
_SPECULATIVE_DECODING_MODELS = {
    "EAGLEModel": ("eagle", "EAGLE"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
210

211
# Generic Transformers-backend implementation. ``_ModelRegistry`` appends
# this architecture as a last-resort fallback when a requested architecture
# is not registered.
_TRANSFORMERS_MODELS = {
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
214
# yapf: enable
215

216
# Union of every built-in table. On duplicate keys the later table wins; in
# the current tables every shared key maps to the same ``(module, class)``
# pair, so precedence does not change the resolved implementation.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}

225
226
227
228
229
230
231
232
# Argument vector handed to ``subprocess.run()`` when model inspection runs in
# a child process. It is a module-level variable so it can be overridden if
# needed: e.g. when the code is packed into a par archive, ``sys.executable``
# might not be the interpreter that should run this module.
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]

233

234
235
@dataclass(frozen=True)
class _ModelInfo:
    """Immutable summary of the capabilities a model class declares.

    Usually built with :meth:`from_model_cls`, which probes the class with
    the predicate helpers imported from ``.interfaces`` and
    ``.interfaces_base``. Field order is part of the positional constructor
    and must not change.
    """

    # ``__name__`` of the inspected model class.
    architecture: str
    is_text_generation_model: bool
    # Set unconditionally by from_model_cls: any model can be converted into
    # a pooling model.
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
        """Probe ``model`` with every interface predicate and collect the
        answers into a new :class:`_ModelInfo`."""
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            # Can convert any model into a pooling model
            is_pooling_model=True,
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            has_noops=has_noops(model),
            supports_transcription=supports_transcription(model),
            supports_v0_only=supports_v0_only(model),
        )
265
266


267
class _BaseRegisteredModel(ABC):
    """Abstract registry entry.

    Concrete subclasses know how to obtain a model class and how to produce
    its :class:`_ModelInfo` capability summary.
    """

    @abstractmethod
    def load_model_cls(self) -> Type[nn.Module]:
        """Return the registered model class (importing it if necessary)."""
        raise NotImplementedError

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability summary of the registered model class."""
        raise NotImplementedError
276
277


278
279
280
281
282
283
284
285
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Registry entry for a model class that is already imported in the main
    process; its capability summary is computed in-process, once.
    """

    interfaces: _ModelInfo
    model_cls: Type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: Type[nn.Module]):
        """Wrap ``model_cls``, eagerly computing its :class:`_ModelInfo`."""
        summary = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=summary, model_cls=model_cls)

    def load_model_cls(self) -> Type[nn.Module]:
        # Nothing to import: the class object itself is stored.
        return self.model_cls

    def inspect_model_cls(self) -> _ModelInfo:
        # Computed at construction time by from_model_cls.
        return self.interfaces


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Registry entry known only by ``module_name`` and ``class_name``. The
    module is not imported in the main process until the class is loaded.
    """
    module_name: str
    class_name: str

    def inspect_model_cls(self) -> _ModelInfo:
        # Import and inspect in a separate process so that importing the
        # model module cannot initialize CUDA in this one.
        def _inspect():
            return _ModelInfo.from_model_cls(self.load_model_cls())

        return _run_in_subprocess(_inspect)

    def load_model_cls(self) -> Type[nn.Module]:
        """Import ``module_name`` and return its ``class_name`` attribute."""
        module = importlib.import_module(self.module_name)
        return getattr(module, self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
    """Load the model class registered for ``model_arch``.

    The current platform first verifies the architecture; any exception from
    that check propagates. An exception raised while loading the class itself
    is logged (with traceback) and turned into a ``None`` result. Results,
    including ``None``, are memoized per ``(model_arch, model)`` pair.
    """
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)

    try:
        loaded = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        loaded = None
    return loaded
332
333


334
335
336
337
338
339
340
341
342
343
344
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Return the capability summary for ``model_arch``, or ``None``.

    Any exception raised during inspection is logged (with traceback) and
    turned into ``None``. Results, including ``None``, are memoized per
    ``(model_arch, model)`` pair.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        info = None
    return info
345
346


347
348
349
350
@dataclass
class _ModelRegistry:
    """
    Registry mapping architecture names (e.g. ``"LlamaForCausalLM"``) to
    registered model entries, plus capability queries over them.

    Query methods accept a single architecture name or a list of candidates.
    The first candidate that can be inspected/loaded wins; see
    :meth:`_normalize_archs` for the Transformers-backend fallback.
    """

    # Keyed by model_arch
    models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> AbstractSet[str]:
        """Return a live view of all registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[Type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        :code:`model_cls` can be either:

        - A :class:`torch.nn.Module` class directly referencing the model.
        - A string in the format :code:`<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Re-registering an existing architecture logs a warning and replaces
        the previous entry.

        Raises:
            TypeError: if ``model_arch`` is not a string, or ``model_cls`` is
                neither a string nor (outside doc builds) an ``nn.Module``
                subclass.
            ValueError: if ``model_cls`` is a string that does not split into
                exactly two parts on ``":"``.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and (is_in_doc_build() or issubclass(
                model_cls, nn.Module)):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the rejected `model_cls`; `model_arch` is
            # already known to be a str here, so its type says nothing.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: List[str]):
        """Always raise ``ValueError`` explaining why no candidate was usable.

        If any candidate is registered, it failed during inspection/loading
        (the traceback was logged by the ``_try_*`` helpers); otherwise none
        of the candidates is supported.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[Type[nn.Module]]:
        """Load ``model_arch``'s class, or return None if unknown/failed."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect ``model_arch``, or return None if unknown/failed."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, List[str]],
    ) -> List[str]:
        """Return the registered subset of ``architectures``, in order.

        If any requested architecture is not registered,
        ``"TransformersForCausalLM"`` is appended as a last-resort fallback.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # keep only the registered (supported) architectures
        normalized_arch = list(
            filter(lambda model: model in self.models, architectures))

        # make sure Transformers backend is put at the last as a fallback
        if len(normalized_arch) != len(architectures):
            normalized_arch.append("TransformersForCausalLM")
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[_ModelInfo, str]:
        """Return ``(info, arch)`` for the first inspectable candidate.

        Raises:
            ValueError: if no candidate can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[Type[nn.Module], str]:
        """Return ``(model_cls, arch)`` for the first loadable candidate.

        Raises:
            ValueError: if no candidate can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the resolved architecture is a text generation model."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the resolved architecture can be used as a pooling model."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the resolved architecture supports cross encoding."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the resolved architecture supports multimodal inputs."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the resolved architecture supports pipeline parallelism."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the resolved architecture declares inner state."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the resolved architecture is attention-free."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the resolved architecture is declared hybrid."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the resolved architecture declares no-op layers."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the resolved architecture supports transcription."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_transcription

    def is_v1_compatible(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the resolved architecture is not restricted to V0."""
        model_cls, _ = self.inspect_model_cls(architectures)
        return not model_cls.supports_v0_only

544
545

# Default registry: every built-in architecture is registered lazily, so no
# model module is imported in this process until its class is loaded.
ModelRegistry = _ModelRegistry({
    arch_name: _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{module_relname}",
        class_name=impl_name,
    )
    for arch_name, (module_relname, impl_name) in _VLLM_MODELS.items()
})

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run ``fn`` in a fresh interpreter and return its result.

    ``fn`` and the path of an output file are serialized with ``cloudpickle``
    (so lambdas and closures work) and fed to the child's stdin. The child
    is launched with ``_SUBPROCESS_COMMAND``, which runs the module's
    ``__main__`` entry point; that entry point is expected to write the
    pickled result of ``fn()`` to the output file.

    Raises:
        RuntimeError: if the child exits with a non-zero status. The message
            includes the child's decoded stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as scratch_dir:
        result_path = os.path.join(scratch_dir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        payload = cloudpickle.dumps((fn, result_path))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        completed = subprocess.run(_SUBPROCESS_COMMAND,
                                   input=payload,
                                   capture_output=True)

        # check_returncode() raises CalledProcessError on a non-zero exit;
        # re-raise it with the child's stderr attached for context.
        try:
            completed.check_returncode()
        except subprocess.CalledProcessError as err:
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{completed.stderr.decode()}") from err

        with open(result_path, "rb") as result_file:
            return pickle.load(result_file)


def _run() -> None:
    """Child-process entry point paired with ``_run_in_subprocess``.

    Reads a pickled ``(fn, output_file)`` tuple from stdin, calls ``fn()``
    and writes the pickled result to ``output_file``.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: stdin is expected to be written by the parent process in
    # _run_in_subprocess; pickle must never be fed untrusted data.
    raw_payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(raw_payload)

    outcome = fn()

    with open(output_file, "wb") as sink:
        sink.write(pickle.dumps(outcome))
595
596
597


if __name__ == "__main__":
    # Reached when launched via `python -m vllm.model_executor.models.registry`
    # (see _SUBPROCESS_COMMAND / _run_in_subprocess).
    _run()