registry.py 22.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
6
import importlib
7
import os
8
import pickle
9
10
import subprocess
import sys
11
import tempfile
12
13
14
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import lru_cache
15
16
from typing import (AbstractSet, Callable, Dict, List, Optional, Tuple, Type,
                    TypeVar, Union)
17

18
import cloudpickle
19
20
21
import torch.nn as nn

from vllm.logger import init_logger
22
from vllm.utils import is_in_doc_build
23

24
25
26
27
from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
                         supports_multimodal, supports_pp,
                         supports_transcription, supports_v0_only)
28
from .interfaces_base import is_text_generation_model
29
30
31

# Logger for registry warnings and model inspection/loading failures.
logger = init_logger(__name__)

# yapf: disable
# Built-in text-generation architectures.
# Maps architecture name -> (module under `vllm.model_executor.models`,
# class name inside that module). Several names alias the same class.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),   # noqa: E501
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

# Built-in architectures usable as embedding / pooling models.
# Same (module, class) value format as `_TEXT_GENERATION_MODELS`.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # [Auto-converted (see adapters.py)]
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

# Built-in cross-encoder (sequence classification) architectures.
# Both RoBERTa variants are served by the same implementation class.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification":
        ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification":
        ("roberta", "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification":
        ("roberta", "RobertaForSequenceClassification"),
}

# Built-in multimodal architectures, in the same (module, class) format.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}

# Built-in draft / speculative-decoding architectures.
_SPECULATIVE_DECODING_MODELS = {
    "EAGLEModel": ("eagle", "EAGLE"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

# Generic Transformers-backend model. `_ModelRegistry._normalize_archs`
# appends "TransformersForCausalLM" as a fallback candidate.
_TRANSFORMERS_MODELS = {
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
# yapf: enable

# Every built-in architecture. On a duplicate key the later table wins; the
# overlapping names between the tables above map to the same (module, class).
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}

# Argument vector handed to `subprocess.run()` when a model class must be
# inspected in a separate process. It is kept as a module-level variable so
# that it can be overridden, e.g. when the code is packed into a par-format
# bundle and `sys.executable` is not the program that should be launched.
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]

@dataclass(frozen=True)
class _ModelInfo:
    """Capability flags describing a model class.

    Built by :meth:`from_model_cls` from the interface helpers imported from
    ``.interfaces`` / ``.interfaces_base``. The dataclass is frozen, so
    instances are immutable and hashable. Lazily registered models compute
    this object in a subprocess, which pickles it back to the parent.
    """

    # Class name of the inspected model (``model.__name__``).
    architecture: str
    is_text_generation_model: bool
    # Always True: any model can be converted into a pooling model.
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    # Negated by `_ModelRegistry.is_v1_compatible`.
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
        """Query every interface helper for ``model`` and collect the flags."""
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=True,  # Can convert any model into a pooling model
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )


class _BaseRegisteredModel(ABC):
    """Interface for an entry stored in :class:`_ModelRegistry`.

    Implementations must be hashable, because instances are passed as
    arguments to the ``lru_cache``-decorated ``_try_load_model_cls`` and
    ``_try_inspect_model_cls`` helpers.
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> Type[nn.Module]:
        """Return the model class, importing it first if necessary."""
        raise NotImplementedError


272
273
274
275
276
277
278
279
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    interfaces: _ModelInfo
    model_cls: Type[nn.Module]
280
281

    @staticmethod
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
    def from_model_cls(model_cls: Type[nn.Module]):
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        return self.interfaces

    def load_model_cls(self) -> Type[nn.Module]:
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """A model known only by where it lives, not yet imported in this process.

    :meth:`load_model_cls` imports the module on demand, while
    :meth:`inspect_model_cls` performs that import in a separate process.
    """

    # Importable module path that defines the model class.
    module_name: str
    # Name of the model class attribute inside ``module_name``.
    class_name: str

    def inspect_model_cls(self) -> _ModelInfo:
        """Compute the model's :class:`_ModelInfo` out-of-process.

        Performed in another process to avoid initializing CUDA here.
        """
        def _inspect_in_child():
            return _ModelInfo.from_model_cls(self.load_model_cls())

        return _run_in_subprocess(_inspect_in_child)

    def load_model_cls(self) -> Type[nn.Module]:
        """Import ``module_name`` and return its ``class_name`` attribute."""
        module = importlib.import_module(self.module_name)
        model_cls = getattr(module, self.class_name)
        return model_cls


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
    """Load the class behind ``model``; return ``None`` if loading fails.

    ``current_platform.verify_model_arch`` runs before the ``try`` block, so
    anything it raises propagates to the caller. Errors raised while loading
    the model class are logged and reported as ``None`` so that callers can
    fall back to the next candidate architecture. Because of ``lru_cache``,
    ``None`` results are memoized too.
    """
    # Local import; presumably kept out of module scope to avoid import-time
    # side effects or cycles -- TODO confirm.
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None


@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Inspect ``model``; on any error, log it and return ``None``.

    Memoized (up to 128 entries) per ``(model_arch, model)`` pair, so a
    failed inspection (``None``) is cached as well.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        return None
    else:
        return info


@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to model entries.

    Keys of :attr:`models` are architecture names such as
    ``"LlamaForCausalLM"``. Values are :class:`_BaseRegisteredModel`
    entries that are either already imported (:class:`_RegisteredModel`) or
    imported on demand (:class:`_LazyRegisteredModel`).

    Query methods accept a single architecture name or a list of candidates
    and act on the first candidate that can be inspected/loaded.
    """

    # Keyed by model_arch
    models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> AbstractSet[str]:
        """Return a live view of the registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[Type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        :code:`model_cls` can be either:

        - A :class:`torch.nn.Module` class directly referencing the model.
        - A string in the format :code:`<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture name that already exists logs a warning
        and replaces the previous entry.

        :raises TypeError: if ``model_arch`` is not a string, or
            ``model_cls`` is neither a string nor an ``nn.Module`` subclass.
        :raises ValueError: if a string ``model_cls`` is not of the form
            ``<module>:<class>``.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and (is_in_doc_build() or issubclass(
                model_cls, nn.Module)):
            # The nn.Module subclass check is skipped during doc builds
            # (presumably because torch may be mocked there -- TODO confirm).
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument, i.e. `model_cls`.
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        # Warn only once the new entry is known to be valid, so a rejected
        # registration does not announce an overwrite that never happens.
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: List[str]):
        """Always raise a ``ValueError`` explaining why resolution failed.

        If any candidate is registered, it failed to be inspected/loaded
        (details were logged); otherwise none of the candidates is supported.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[Type[nn.Module]]:
        """Load ``model_arch``'s class, or ``None`` if unknown or failing."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect ``model_arch``, or return ``None`` if unknown or failing."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, List[str]],
    ) -> List[str]:
        """Return the registered subset of ``architectures``, in order.

        A single string is treated as a one-element list. If any requested
        architecture is not registered, ``"TransformersForCausalLM"`` is
        appended as a last-resort fallback.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        # Keep only registered architectures, preserving the caller's order.
        normalized_arch = [
            arch for arch in architectures if arch in self.models
        ]

        # make sure Transformers backend is put at the last as a fallback
        if len(normalized_arch) != len(architectures):
            normalized_arch.append("TransformersForCausalLM")
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[_ModelInfo, str]:
        """Return ``(info, arch)`` for the first inspectable candidate.

        :raises ValueError: if no candidate can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[Type[nn.Module], str]:
        """Return ``(model_cls, arch)`` for the first loadable candidate.

        :raises ValueError: if no candidate can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    # Each predicate below inspects the first usable candidate (see
    # `inspect_model_cls`, which raises ValueError if there is none) and
    # returns one flag of its `_ModelInfo`.

    def is_text_generation_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return ``_ModelInfo.is_text_generation_model``."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return ``_ModelInfo.is_pooling_model``."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return ``_ModelInfo.supports_cross_encoding``."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return ``_ModelInfo.supports_multimodal``."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return ``_ModelInfo.supports_pp``."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return ``_ModelInfo.has_inner_state``."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return ``_ModelInfo.is_attention_free``."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return ``_ModelInfo.is_hybrid``."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

    def is_noops_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return ``_ModelInfo.has_noops``."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_noops

    def is_transcription_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return ``_ModelInfo.supports_transcription``."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription

    def is_v1_compatible(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the negation of ``_ModelInfo.supports_v0_only``."""
        model_info, _ = self.inspect_model_cls(architectures)
        return not model_info.supports_v0_only

# Global registry pre-populated with every built-in architecture. Entries are
# lazy, so importing this module does not import any model implementation.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Call ``fn`` in a fresh Python process and return its result.

    ``fn`` and an output path are serialized with ``cloudpickle`` and sent
    on stdin to ``_SUBPROCESS_COMMAND``. The child calls ``fn`` and pickles
    the result into a file inside a temporary directory, which is read back
    here.

    :raises RuntimeError: if the child exits with a non-zero status. The
        child's stderr is included in the message.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the error so the child's stderr reaches the caller.
            # `errors="replace"` keeps undecodable output from masking the
            # real failure behind a UnicodeDecodeError.
            stderr = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Child-process entry point executed via ``_SUBPROCESS_COMMAND``.

    Reads a pickled ``(fn, output_file)`` pair from stdin, calls ``fn()``,
    and writes the pickled result to ``output_file``.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # SECURITY: unpickling can execute arbitrary code. stdin is expected to
    # come only from `_run_in_subprocess` in the parent process; never pipe
    # untrusted data into this entry point.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)


if __name__ == "__main__":
    # Reached when launched as `python -m vllm.model_executor.models.registry`
    # (see `_SUBPROCESS_COMMAND`).
    _run()