"""
Registry of the model architectures supported by vLLM.

Whenever you add an architecture to this registry, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
import importlib
import os
import pickle
import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import lru_cache
from typing import (AbstractSet, Callable, Dict, List, NoReturn, Optional,
                    Tuple, Type, TypeVar, Union)

import cloudpickle
import torch.nn as nn

from vllm.logger import init_logger
from vllm.platforms import current_platform

from .adapters import as_embedding_model
from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
                         supports_cross_encoding, supports_multimodal,
                         supports_pp)
from .interfaces_base import is_pooling_model, is_text_generation_model

# Module-level logger; used below for registration warnings and for logging
# failures to load or inspect a model architecture.
logger = init_logger(__name__)

31
# yapf: disable
32
33
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
34
35
36
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
37
38
39
40
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
41
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
42
    # ChatGLMModel supports multimodal
43
44
45
46
47
48
49
50
51
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
52
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
53
54
55
56
57
58
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
59
    "GritLM": ("gritlm", "GritLM"),
60
61
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
62
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
63
64
65
66
67
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
68
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
69
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
70
71
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
72
73
74
75
76
77
78
79
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
80
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
81
82
83
84
85
86
87
88
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
89
    # QWenLMHeadModel supports multimodal
90
91
92
93
94
95
96
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
97
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
98
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
99
100
101
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
102
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
103
104
105
}

# Pooling/embedding architectures: architecture name ->
# (module name relative to `vllm.model_executor.models`, class name).
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all:
        # every text-generation alias implemented by `LlamaForCausalLM`
        # (e.g. Mistral, InternLM, Xverse) is also usable for embedding.
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"),  # noqa: E501
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
}

_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
}

142
_MULTIMODAL_MODELS = {
143
    # [Decoder-only]
144
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
145
146
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
147
148
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
149
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
150
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
151
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
152
    "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
153
154
155
156
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
157
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
158
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
159
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
160
161
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
162
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
163
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
164
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
165
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
166
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
167
    "UltravoxModel": ("ultravox", "UltravoxModel"),
168
169
    # [Encoder-decoder]
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
170
}
171
172
173
174
175

_SPECULATIVE_DECODING_MODELS = {
    "EAGLEModel": ("eagle", "EAGLE"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
176
}
177
# yapf: enable
178

179
# All built-in architectures. When a name appears in more than one table, the
# entry from the later table in this list wins (dict-unpacking semantics).
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
}

# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
                    "Triton flash attention. For half-precision SWA support, "
                    "please use CK flash attention by setting "
                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    "Qwen2ForCausalLM":
    _ROCM_SWA_REASON,
    "MistralForCausalLM":
    _ROCM_SWA_REASON,
    "MixtralForCausalLM":
    _ROCM_SWA_REASON,
    "PaliGemmaForConditionalGeneration":
    ("ROCm flash attention does not yet "
     "fully support 32-bit precision on PaliGemma"),
    "Phi3VForCausalLM":
    ("ROCm Triton flash attention may run into compilation errors due to "
     "excessive use of shared memory. If this happens, disable Triton FA "
     "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}


213
214
@dataclass(frozen=True)
class _ModelInfo:
215
    architecture: str
216
    is_text_generation_model: bool
217
    is_pooling_model: bool
218
    supports_cross_encoding: bool
219
220
    supports_multimodal: bool
    supports_pp: bool
221
222
    has_inner_state: bool
    is_attention_free: bool
223
    is_hybrid: bool
224
225

    @staticmethod
226
    def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
227
228
        is_pooling_model_ = is_pooling_model(model)
        if not is_pooling_model_:
229
230
231
232
233
            try:
                as_embedding_model(model)
            except Exception:
                pass
            else:
234
                is_pooling_model_ = True
235

236
        return _ModelInfo(
237
            architecture=model.__name__,
238
            is_text_generation_model=is_text_generation_model(model),
239
            is_pooling_model=is_pooling_model_,
240
            supports_cross_encoding=supports_cross_encoding(model),
241
242
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
243
244
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
245
            is_hybrid=is_hybrid(model),
246
        )
247
248


249
class _BaseRegisteredModel(ABC):
    """Interface of an entry in the model registry.

    Implementations either hold an already-imported model class or import it
    lazily on demand.
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the underlying model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> Type[nn.Module]:
        """Return the underlying model class, importing it if necessary."""
        raise NotImplementedError


@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability flags, computed eagerly when the entry is created.
    interfaces: _ModelInfo
    model_cls: Type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: Type[nn.Module]) -> "_RegisteredModel":
        """Wrap ``model_cls``, inspecting its interfaces immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the flags computed at construction time."""
        return self.interfaces

    def load_model_cls(self) -> Type[nn.Module]:
        """Return the stored class; no import is needed."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    A model known only by its module and class names. Nothing is imported in
    the main process until :meth:`load_model_cls` is called.
    """
    module_name: str
    class_name: str

    def inspect_model_cls(self) -> _ModelInfo:
        # Importing the model module here could initialize CUDA in the main
        # process, so the import and inspection run in a child interpreter.
        def _inspect():
            return _ModelInfo.from_model_cls(self.load_model_cls())

        return _run_in_subprocess(_inspect)

    def load_model_cls(self) -> Type[nn.Module]:
        module = importlib.import_module(self.module_name)
        model_cls = getattr(module, self.class_name)
        return model_cls


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
    """Load the model class registered for ``model_arch``.

    On ROCm, architectures listed in ``_ROCM_UNSUPPORTED_MODELS`` raise, and
    those in ``_ROCM_PARTIALLY_SUPPORTED_MODELS`` log a warning before loading.
    Results (including ``None`` for failures) are memoized per
    ``(model_arch, model)`` pair by ``lru_cache``; raised exceptions are not.

    :returns: the model class, or ``None`` if loading failed (the error is
        logged with its traceback).
    :raises ValueError: if the architecture is unsupported on ROCm.
    """
    if current_platform.is_rocm():
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(f"Model architecture '{model_arch}' is not "
                             "supported by ROCm for now.")

        reason = _ROCM_PARTIALLY_SUPPORTED_MODELS.get(model_arch)
        if reason is not None:
            logger.warning(
                "Model architecture '%s' is partially "
                "supported by ROCm: %s", model_arch, reason)

    try:
        return model.load_model_cls()
    except Exception:
        # Log and signal failure with None so the caller can try the next
        # candidate architecture.
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None


@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Inspect ``model``; on failure, log the traceback and return ``None``.

    Memoized (failures included) per ``(model_arch, model)`` pair.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        info = None
    return info


@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to registered model entries.

    Query methods accept either a single architecture name or a list of
    candidate names. Candidates are tried in order, and the first one that can
    be inspected (or loaded) is used.
    """

    # Keyed by model_arch
    models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> AbstractSet[str]:
        """Return the registered architecture names (a live keys view)."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[Type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        :code:`model_cls` can be either:

        - A :class:`torch.nn.Module` class directly referencing the model.
        - A string in the format :code:`<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        An existing registration for :code:`model_arch` is overwritten, and a
        warning is logged.

        :raises ValueError: if :code:`model_cls` is a string that is not of
            the form :code:`<module>:<class>` with both parts non-empty.
        """
        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            # Also reject empty parts such as ":Cls" or "pkg.mod:", which
            # would otherwise only fail later, when the lazy import runs.
            if len(split_str) != 2 or not all(split_str):
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model: _BaseRegisteredModel = _LazyRegisteredModel(*split_str)
        else:
            model = _RegisteredModel.from_model_cls(model_cls)

        # Warn only once the new entry is known to be valid, so a rejected
        # registration does not claim to overwrite the existing one.
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: List[str]) -> NoReturn:
        """Raise ``ValueError`` explaining why no candidate could be used.

        Distinguishes candidates that are registered but failed (to load or
        inspect; details are in the logs) from candidates that are not
        registered at all.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[Type[nn.Module]]:
        """Load one architecture's class; ``None`` if unknown or failed."""
        model = self.models.get(model_arch)
        if model is None:
            return None

        return _try_load_model_cls(model_arch, model)

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect one architecture; ``None`` if unknown or failed."""
        model = self.models.get(model_arch)
        if model is None:
            return None

        return _try_inspect_model_cls(model_arch, model)

    def _normalize_archs(
        self,
        architectures: Union[str, List[str]],
    ) -> List[str]:
        """Wrap a single name in a list; warn if no names were given."""
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        return architectures

    def inspect_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[_ModelInfo, str]:
        """Return ``(model_info, arch)`` for the first inspectable candidate.

        :raises ValueError: if no candidate can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[Type[nn.Module], str]:
        """Return ``(model_cls, arch)`` for the first loadable candidate.

        :raises ValueError: if no candidate can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the first inspectable candidate generates text."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the first inspectable candidate is (or adapts to) pooling."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the first inspectable candidate supports cross-encoding."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the first inspectable candidate is multimodal."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the first inspectable candidate supports pipeline parallelism."""  # noqa: E501
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the first inspectable candidate has inner state."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the first inspectable candidate is attention-free."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the first inspectable candidate is a hybrid model."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid


# The built-in registry. Every entry is lazy, so a model's module is only
# imported when that architecture is actually needed.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{relative_module}",
            class_name=class_name,
        )
        for arch_name, (relative_module, class_name) in _VLLM_MODELS.items()
    })

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Run ``fn`` in a fresh Python interpreter and return its result.

    ``fn`` is serialized with cloudpickle, so lambdas and closures work. It is
    executed by this module's ``__main__`` entry point in a child process, and
    its return value is pickled back through a temporary file.

    :raises RuntimeError: if the child process exits with a non-zero status;
        the message includes the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            [sys.executable, "-m", "vllm.model_executor.models.registry"],
            input=input_bytes,
            capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the error to surface the child's stderr. Decode leniently:
            # a strict decode could itself raise UnicodeDecodeError here and
            # hide the real failure.
            stderr = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Child-process entry point used by ``_run_in_subprocess``.

    Reads a pickled ``(fn, output_file)`` pair from stdin, calls ``fn``, and
    pickles its return value into ``output_file``.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # The payload is written by the parent process in `_run_in_subprocess`.
    payload = sys.stdin.buffer.read()
    fn, output_file = pickle.loads(payload)

    result = fn()
    with open(output_file, "wb") as f:
        pickle.dump(result, f)


if __name__ == "__main__":
    # Entry point of the child interpreter that `_run_in_subprocess` launches
    # with `python -m vllm.model_executor.models.registry`.
    _run()