registry.py 21.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
6
import importlib
7
import os
8
import pickle
9
10
import subprocess
import sys
11
import tempfile
12
13
14
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import lru_cache
15
16
from typing import (AbstractSet, Callable, Dict, List, Optional, Tuple, Type,
                    TypeVar, Union)
17

18
import cloudpickle
19
20
21
22
import torch.nn as nn

from vllm.logger import init_logger

23
from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
24
                         supports_cross_encoding, supports_multimodal,
25
                         supports_pp, supports_transcription)
26
from .interfaces_base import is_text_generation_model
27
28
29

logger = init_logger(__name__)

30
# yapf: disable
31
32
# Each entry maps an architecture name to a ``(module, class)`` pair: the
# module name relative to ``vllm.model_executor.models`` and the name of the
# class to load from that module.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
}

# Architectures served as embedding models, as ``(module, class)`` pairs.
# Every text-generation architecture backed by "LlamaForCausalLM" is pulled
# in from `_TEXT_GENERATION_MODELS` by the dict spread below.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # [Auto-converted (see adapters.py)]
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

147
148
149
150
151
152
153
154
# Sequence-classification architectures, as ``(module, class)`` pairs.
# The XLM-RoBERTa entry reuses the RoBERTa implementation.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),  # noqa: E501
    "XLMRobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),  # noqa: E501
}

155
# Multimodal architectures, as ``(module, class)`` pairs where the module
# name is relative to ``vllm.model_executor.models``.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    # [Encoder-decoder]
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
187
188
189

# Speculative-decoding architectures, as ``(module, class)`` pairs.
_SPECULATIVE_DECODING_MODELS = {
    "EAGLEModel": ("eagle", "EAGLE"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
194
195
196
197

# The "TransformersModel" key is what `_ModelRegistry._normalize_archs`
# substitutes for any architecture name that is not registered.
_FALLBACK_MODEL = {"TransformersModel": ("transformers", "TransformersModel")}
198
# yapf: enable
199

200
# Union of every built-in table above. When the same architecture name
# appears in more than one table, the entry from the later table wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_FALLBACK_MODEL,
}

209
210
211
212
213
214
215
216
# Argument list passed to subprocess.run() by `_run_in_subprocess`.
# It is kept in a module-level variable so that it can be modified when
# needed, e.g. when things are packed together in par format and
# sys.executable might not be the target we want to run.
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]

217

218
219
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable summary of a model class: its name and its capability flags.

    Instances are plain data, so they can be returned from a subprocess
    without importing the model class in the main process.
    """

    # `__name__` of the inspected model class
    architecture: str
    is_text_generation_model: bool
    # Always True when built through `from_model_cls`
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    supports_transcription: bool

    @classmethod
    def from_model_cls(cls, model: Type[nn.Module]) -> "_ModelInfo":
        """
        Build the summary by querying each interface helper on ``model``.

        :param model: The model class (not an instance) to inspect.
        :return: A new :class:`_ModelInfo` describing ``model``.
        """
        return cls(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=True,  # Can convert any model into a pooling model
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
        )
244
245


246
class _BaseRegisteredModel(ABC):
    """
    Interface shared by every entry stored in the model registry.
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the :class:`_ModelInfo` describing the model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> Type[nn.Module]:
        """Return the model class itself."""
        raise NotImplementedError
255
256


257
258
259
260
261
262
263
264
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability summary of `model_cls` (built by `from_model_cls`)
    interfaces: _ModelInfo
    model_cls: Type[nn.Module]

    @classmethod
    def from_model_cls(
        cls,
        model_cls: Type[nn.Module],
    ) -> "_RegisteredModel":
        """
        Wrap an already-imported model class together with its summary.

        :param model_cls: The model class to register.
        :return: A new entry holding ``model_cls`` and its
            :class:`_ModelInfo`.
        """
        return cls(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the summary stored on this entry."""
        return self.interfaces

    def load_model_cls(self) -> Type[nn.Module]:
        """Return the model class held by this entry."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Dotted path of the module that defines the model class
    module_name: str
    # Attribute name of the model class inside that module
    class_name: str

    def inspect_model_cls(self) -> _ModelInfo:
        """Build the model summary without importing the model here."""

        def _inspect() -> _ModelInfo:
            loaded_cls = self.load_model_cls()
            return _ModelInfo.from_model_cls(loaded_cls)

        # Performed in another process to avoid initializing CUDA
        return _run_in_subprocess(_inspect)

    def load_model_cls(self) -> Type[nn.Module]:
        """Import ``module_name`` and return its ``class_name`` attribute."""
        return getattr(importlib.import_module(self.module_name),
                       self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
    """
    Load the model class of ``model``, returning ``None`` if loading fails.

    Results are memoized per ``(model_arch, model)`` pair, so both arguments
    must be hashable; a ``None`` result is memoized as well.

    :param model_arch: Architecture name, verified against the current
        platform and used in the error log.
    :param model: Registry entry whose class should be loaded.
    :return: The model class, or ``None`` when ``load_model_cls`` raised.
    """
    from vllm.platforms import current_platform

    # Called outside the `try` block: anything raised here propagates to
    # the caller instead of being logged and turned into `None`.
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
311
312


313
314
315
316
317
318
319
320
321
322
323
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """
    Inspect ``model``, returning ``None`` if the inspection fails.

    Results are memoized per ``(model_arch, model)`` pair; ``model_arch``
    is otherwise only used in the error log.
    """
    try:
        inspected = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        inspected = None
    return inspected
324
325


326
327
328
329
@dataclass
class _ModelRegistry:
    """
    Maps architecture names to registered models and answers capability
    queries about them.
    """

    # Keyed by model_arch
    models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> AbstractSet[str]:
        """Return a view of all registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[Type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        :code:`model_cls` can be either:

        - A :class:`torch.nn.Module` class directly referencing the model.
        - A string in the format :code:`<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        An existing registration under the same :code:`model_arch` is
        overwritten after a warning is logged.

        :raises TypeError: If :code:`model_arch` is not a string, or
            :code:`model_cls` is neither a string nor a
            :class:`torch.nn.Module` subclass.
        :raises ValueError: If :code:`model_cls` is a string that is not in
            the format :code:`<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Report the type of the offending argument, i.e. `model_cls`
            msg = ("`model_cls` should be a string or PyTorch model class, "
                   f"not a {type(model_cls)}")
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: List[str]):
        """
        Always raise :class:`ValueError` for ``architectures``.

        The message distinguishes between architectures that are registered
        but could not be used and architectures that are not registered.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[Type[nn.Module]]:
        """Load the class for ``model_arch``; ``None`` if unknown or failed."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect ``model_arch``; ``None`` if unknown or failed."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, List[str]],
    ) -> List[str]:
        """
        Return ``architectures`` as a list, with every unregistered name
        replaced by ``"TransformersModel"``.

        A single string is treated as a one-element list. An empty input
        only logs a warning and yields an empty list.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        return [
            arch if arch in self.models else "TransformersModel"
            for arch in architectures
        ]

    def inspect_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[_ModelInfo, str]:
        """
        Return the info of the first architecture that can be inspected,
        together with that architecture's name.

        :raises ValueError: If none of the architectures can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[Type[nn.Module], str]:
        """
        Return the class of the first architecture that can be loaded,
        together with that architecture's name.

        :raises ValueError: If none of the architectures can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the ``is_text_generation_model`` flag of the model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the ``is_pooling_model`` flag of the model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the ``supports_cross_encoding`` flag of the model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the ``supports_multimodal`` flag of the model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the ``supports_pp`` flag of the model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the ``has_inner_state`` flag of the model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the ``is_attention_free`` flag of the model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the ``is_hybrid`` flag of the model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

    def is_transcription_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the ``supports_transcription`` flag of the model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription

506
507

# Global registry, pre-populated with every built-in architecture. Entries
# are lazy: only the module path and class name are stored here.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

# Return type of the callable passed to `_run_in_subprocess`
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Run ``fn`` in a separate Python process and return its result.

    ``fn`` and an output file path are serialized with ``cloudpickle`` and
    passed on stdin to the command in ``_SUBPROCESS_COMMAND``. The result is
    then read back by unpickling that output file.

    :param fn: Zero-argument callable to execute in the subprocess.
    :return: The object unpickled from the output file.
    :raises RuntimeError: If the subprocess exits with a non-zero return
        code; the message contains the subprocess's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently: undecodable bytes in stderr must not raise
            # here and hide the actual subprocess failure.
            stderr_text = returned.stderr.decode(errors="replace")
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr_text}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """
    Subprocess entry point: execute the callable received on stdin.

    Reads a pickled ``(fn, output_file)`` pair from stdin, calls ``fn()``
    and writes the pickled result to ``output_file``.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE(review): unpickling executes arbitrary code. Presumably stdin is
    # only ever fed by `_run_in_subprocess` -- confirm before exposing this
    # entry point to any other producer.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)
557
558
559


# Reached when this module is executed with `-m`, which is how
# `_SUBPROCESS_COMMAND` invokes it.
if __name__ == "__main__":
    _run()