# vllm/model_executor/models/registry.py
"""
Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
import importlib
import os
import pickle
import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import lru_cache
from typing import (AbstractSet, Callable, Dict, List, NoReturn, Optional,
                    Tuple, Type, TypeVar, Union)

import cloudpickle
import torch.nn as nn

from vllm.logger import init_logger
from vllm.platforms import current_platform

from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
                         supports_cross_encoding, supports_multimodal,
                         supports_pp)
from .interfaces_base import is_text_generation_model

# Module-level logger; used below for ROCm/overwrite warnings and for logging
# exceptions raised while loading or inspecting model classes.
logger = init_logger(__name__)

# yapf: disable
# Architecture name -> (module under `vllm.model_executor.models`, class name).
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    # ChatGLMModel supports multimodal, so it is registered in
    # `_MULTIMODAL_MODELS` instead of here.
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    # QWenLMHeadModel supports multimodal, so it is registered in
    # `_MULTIMODAL_MODELS` instead of here.
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
}

# Architecture name -> (module under `vllm.model_executor.models`, class name)
# for models that can be served as embedding / pooling models.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # [Auto-converted (see adapters.py)]
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
}

# Architecture name -> (module, class name) for sequence-classification
# (cross-encoder) models.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),  # noqa: E501
    "XLMRobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),  # noqa: E501
}

# Architecture name -> (module under `vllm.model_executor.models`, class name)
# for models that accept non-text (image / video / audio) inputs.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    # [Encoder-decoder]
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
}

# Draft / speculator architectures used for speculative decoding.
_SPECULATIVE_DECODING_MODELS = {
    "EAGLEModel": ("eagle", "EAGLE"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
# yapf: enable

# Union of every built-in architecture table. Keys that appear in more than
# one table take the value from the table unpacked last.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
}

# Architectures that cannot run on ROCm at all (currently none).
_ROCM_UNSUPPORTED_MODELS: List[str] = []

# Architectures that run on ROCm with caveats, mapped to the explanation that
# is logged when one of them is loaded.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
                    "Triton flash attention. For half-precision SWA support, "
                    "please use CK flash attention by setting "
                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    # All sliding-window-attention models share the same explanation.
    **dict.fromkeys(
        ("Qwen2ForCausalLM", "MistralForCausalLM", "MixtralForCausalLM"),
        _ROCM_SWA_REASON,
    ),
    "PaliGemmaForConditionalGeneration": (
        "ROCm flash attention does not yet "
        "fully support 32-bit precision on PaliGemma"),
    "Phi3VForCausalLM": (
        "ROCm Triton flash attention may run into compilation errors due to "
        "excessive use of shared memory. If this happens, disable Triton FA "
        "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`"),
}


@dataclass(frozen=True)
class _ModelInfo:
    """
    Capability summary of a model class.

    Each flag is computed by the corresponding interface-detection helper
    imported from ``.interfaces`` / ``.interfaces_base``. Instances are
    frozen, so they are immutable and hashable.
    """
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool

    @staticmethod
    def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
        """Inspect ``model`` and record which interfaces it implements."""
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=True,  # Can convert any model into a pooling model
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
        )


class _BaseRegisteredModel(ABC):
    """
    Abstract registry entry: can report a model class's capabilities and
    return the class itself.
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability summary of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> Type[nn.Module]:
        """Return the registered model class."""
        raise NotImplementedError


@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    Its capabilities are computed once, in this process, when the entry is
    created via :meth:`from_model_cls`.
    """

    interfaces: _ModelInfo
    model_cls: Type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: Type[nn.Module]) -> "_RegisteredModel":
        """Wrap an imported model class, inspecting it immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        return self.interfaces

    def load_model_cls(self) -> Type[nn.Module]:
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Registry entry for a model that the main process has not imported.

    The model is identified only by ``module_name`` and ``class_name``; the
    module is imported on demand.
    """
    module_name: str
    class_name: str

    def inspect_model_cls(self) -> _ModelInfo:
        # Import and inspect the class in another process so that importing
        # the model does not initialize CUDA in this one.
        def _inspect_in_child():
            return _ModelInfo.from_model_cls(self.load_model_cls())

        return _run_in_subprocess(_inspect_in_child)

    def load_model_cls(self) -> Type[nn.Module]:
        module = importlib.import_module(self.module_name)
        model_cls = getattr(module, self.class_name)
        return model_cls


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
    """
    Load the model class behind ``model``.

    On ROCm, architectures listed in ``_ROCM_UNSUPPORTED_MODELS`` raise
    ``ValueError`` and those in ``_ROCM_PARTIALLY_SUPPORTED_MODELS`` log a
    warning before loading.

    Returns the class, or ``None`` if loading raised (the traceback is
    logged). Results are memoized per ``(model_arch, model)`` by
    ``lru_cache``; raised exceptions are not cached.
    """
    if current_platform.is_rocm():
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(f"Model architecture '{model_arch}' is not "
                             "supported by ROCm for now.")

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially "
                "supported by ROCm: %s", model_arch, msg)

    try:
        return model.load_model_cls()
    except Exception:
        # Best-effort: log and report failure with None so the caller can
        # fall back to another candidate architecture.
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None


@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """
    Return the capability summary of ``model``, or ``None`` when inspection
    raises. Failures are logged with their traceback; successful and failed
    (``None``) results are memoized per ``(model_arch, model)``.
    """
    info = None
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
    return info


@dataclass
class _ModelRegistry:
    """
    Registry mapping architecture names to loadable/inspectable entries.

    Lookup methods accept one architecture name or a list of candidates and
    use the first candidate that can be loaded or inspected.
    """
    # Keyed by model_arch
    models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> AbstractSet[str]:
        """Return a live view of all registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[Type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        :code:`model_cls` can be either:

        - A :class:`torch.nn.Module` class directly referencing the model.
        - A string in the format :code:`<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an already-registered architecture overwrites the
        previous entry and logs a warning.

        :raises ValueError: if :code:`model_cls` is a string that is not of
            the form :code:`<module>:<class>`.
        """
        # Build (and thereby validate) the new entry before touching the
        # registry, so an invalid spec neither warns about nor performs an
        # overwrite.
        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model: _BaseRegisteredModel = _LazyRegisteredModel(*split_str)
        else:
            model = _RegisteredModel.from_model_cls(model_cls)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: List[str]) -> NoReturn:
        """
        Always raise ``ValueError`` for a list of unusable architectures.

        The message distinguishes "at least one is registered but failed"
        (details are in the logs) from "none of them is registered".
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[Type[nn.Module]]:
        """Load ``model_arch``; ``None`` if unregistered or loading failed."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect ``model_arch``; ``None`` if unregistered or it failed."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, List[str]],
    ) -> List[str]:
        """Wrap a single name in a list; warn when no names are given."""
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        return architectures

    def inspect_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[_ModelInfo, str]:
        """
        Return ``(model_info, arch)`` for the first architecture, in order,
        that can be inspected.

        :raises ValueError: if none of the architectures can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[Type[nn.Module], str]:
        """
        Return ``(model_cls, arch)`` for the first architecture, in order,
        that can be loaded.

        :raises ValueError: if none of the architectures can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid


# The global registry. Every built-in architecture is registered lazily, by
# import path only, so no model module is imported until it is needed.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{rel_module}",
            class_name=target_cls,
        )
        for arch_name, (rel_module, target_cls) in _VLLM_MODELS.items()
    })

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Call ``fn`` in a fresh Python interpreter and return its result.

    ``fn`` is serialized with ``cloudpickle`` and sent on stdin to
    ``python -m vllm.model_executor.models.registry``. The child process
    writes the pickled result to a file in a temporary directory, which is
    read back here.

    :raises RuntimeError: if the child exits with a non-zero status; the
        message contains the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            [sys.executable, "-m", "vllm.model_executor.models.registry"],
            input=input_bytes,
            capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # wrap raised exception to provide more information. Decode
            # leniently: a strict decode of non-UTF-8 stderr would raise
            # UnicodeDecodeError here and hide the real failure.
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr_text}") from e

        # The file was written by our own child process inside a private
        # temporary directory, so unpickling it is trusted.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """
    Child-process entry point used by ``_run_in_subprocess``.

    Reads a pickled ``(fn, output_file)`` pair from stdin, calls ``fn()`` and
    writes the pickled result to ``output_file``.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # SECURITY NOTE: unpickling runs arbitrary code from its input. This is
    # acceptable only because stdin is written by the parent vLLM process;
    # never feed untrusted data to this entry point.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)


if __name__ == "__main__":
    _run()