registry.py 20.7 KB
Newer Older
1
2
3
4
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
5
import importlib
6
import os
7
import pickle
8
9
import subprocess
import sys
10
import tempfile
11
12
13
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import lru_cache
14
15
from typing import (AbstractSet, Callable, Dict, List, Optional, Tuple, Type,
                    TypeVar, Union)
16

17
import cloudpickle
18
19
20
import torch.nn as nn

from vllm.logger import init_logger
21
from vllm.platforms import current_platform
22

23
from .adapters import as_embedding_model
24
from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
25
26
                         supports_cross_encoding, supports_multimodal,
                         supports_pp)
27
from .interfaces_base import is_pooling_model, is_text_generation_model
28
29
30

# Module-level logger used by the registry helpers below.
logger = init_logger(__name__)

31
# yapf: disable
32
33
# Architecture name -> (module name relative to `vllm.model_executor.models`,
# name of the class inside that module).
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    # ChatGLMModel supports multimodal
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    # QWenLMHeadModel supports multimodal
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
}

# Architecture name -> (module name relative to `vllm.model_executor.models`,
# name of the class inside that module).
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"),  # noqa: E501
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
}

136
137
138
139
140
141
142
143
# Architecture name -> (module name relative to `vllm.model_executor.models`,
# name of the class inside that module).
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": (
        "bert",
        "BertForSequenceClassification",
    ),
    # Both RoBERTa variants resolve to the same class.
    "RobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
    "XLMRobertaForSequenceClassification": (
        "roberta",
        "RobertaForSequenceClassification",
    ),
}

144
# Architecture name -> (module name relative to `vllm.model_executor.models`,
# name of the class inside that module).
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    # [Encoder-decoder]
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
}
173
174
175
176
177

# Architecture name -> (module name relative to `vllm.model_executor.models`,
# name of the class inside that module).
_SPECULATIVE_DECODING_MODELS = {
    "EAGLEModel": ("eagle", "EAGLE"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
179
# yapf: enable
180

181
# Union of every built-in architecture table. When the same architecture name
# appears in more than one table, the entry from the table listed last wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
}

# Architectures that cannot be loaded on ROCm at all; loading one of them on
# a ROCm platform raises a ValueError.
_ROCM_UNSUPPORTED_MODELS: List[str] = []

# Shared explanation for architectures that rely on sliding window attention.
_ROCM_SWA_REASON = (
    "Sliding window attention (SWA) is not yet supported in "
    "Triton flash attention. For half-precision SWA support, "
    "please use CK flash attention by setting "
    "`VLLM_USE_TRITON_FLASH_ATTN=0`"
)

# Architectures that load on ROCm but with caveats.
# Architecture -> reason that is logged as a warning at load time.
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    "Qwen2ForCausalLM": _ROCM_SWA_REASON,
    "MistralForCausalLM": _ROCM_SWA_REASON,
    "MixtralForCausalLM": _ROCM_SWA_REASON,
    "PaliGemmaForConditionalGeneration": (
        "ROCm flash attention does not yet "
        "fully support 32-bit precision on PaliGemma"
    ),
    "Phi3VForCausalLM": (
        "ROCm Triton flash attention may run into compilation errors due to "
        "excessive use of shared memory. If this happens, disable Triton FA "
        "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`"
    ),
}


215
216
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable summary of a model class: its name plus the capability flags
    reported by the interface predicates imported at the top of this file.
    """

    # `__name__` of the inspected model class.
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool

    @staticmethod
    def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
        """
        Build the summary for :code:`model` by querying each predicate once.

        A class that is not itself a pooling model is still reported as one
        when :code:`as_embedding_model` accepts it without raising.
        """
        is_pooling_model_ = is_pooling_model(model)
        if not is_pooling_model_:
            try:
                as_embedding_model(model)
            except Exception:
                # Deliberate best-effort probe: any failure only means the
                # class cannot be adapted, so the flag simply stays False.
                pass
            else:
                is_pooling_model_ = True

        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model_,
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
        )
249
250


251
class _BaseRegisteredModel(ABC):
    """
    Interface shared by every registry entry: an entry can report a
    :class:`_ModelInfo` summary and can hand back the model class itself.
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability summary of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> Type[nn.Module]:
        """Return the registered model class."""
        raise NotImplementedError
260
261


262
263
264
265
266
267
268
269
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability summary computed once, when the entry is created.
    interfaces: _ModelInfo
    model_cls: Type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: Type[nn.Module]) -> "_RegisteredModel":
        """Create an entry for :code:`model_cls`, inspecting it eagerly."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the summary stored at construction time."""
        return self.interfaces

    def load_model_cls(self) -> Type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Registry entry for a model that has not been imported in the main
    process; it stores only where the class can be found.
    """
    module_name: str
    class_name: str

    def load_model_cls(self) -> Type[nn.Module]:
        """Import :code:`module_name` and return its :code:`class_name`."""
        module = importlib.import_module(self.module_name)
        model_cls = getattr(module, self.class_name)
        return model_cls

    def inspect_model_cls(self) -> _ModelInfo:
        """
        Load and inspect the class in another process, to avoid
        initializing CUDA in this one.
        """
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
    """
    Load the class behind :code:`model`, returning :code:`None` on failure.

    Loading errors are logged, not propagated. Results (including
    :code:`None`) are memoized per :code:`(model_arch, model)` pair.

    Raises:
        ValueError: On ROCm, if :code:`model_arch` is listed in
            :code:`_ROCM_UNSUPPORTED_MODELS`.
    """
    if current_platform.is_rocm():
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(f"Model architecture '{model_arch}' is not "
                             "supported by ROCm for now.")

        # Partially supported architectures still load, but with a warning.
        msg = _ROCM_PARTIALLY_SUPPORTED_MODELS.get(model_arch)
        if msg is not None:
            logger.warning(
                "Model architecture '%s' is partially "
                "supported by ROCm: %s", model_arch, msg)

    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
325
326


327
328
329
330
331
332
333
334
335
336
337
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """
    Inspect :code:`model`, returning :code:`None` when inspection fails.

    The failure is logged with its traceback. Results (including
    :code:`None`) are memoized per :code:`(model_arch, model)` pair.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        info = None
    return info
338
339


340
341
342
343
@dataclass
class _ModelRegistry:
    """
    Maps architecture names to registry entries and answers capability
    queries about them.

    Every public query accepts either one architecture name or a list of
    names; the first name that can be inspected or loaded is the one used.
    """

    # Keyed by model_arch
    models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> AbstractSet[str]:
        """Return a live view of every registered architecture name."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[Type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        :code:`model_cls` can be either:

        - A :class:`torch.nn.Module` class directly referencing the model.
        - A string in the format :code:`<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        An existing registration for :code:`model_arch` is overwritten after
        a warning is logged.

        Raises:
            ValueError: If :code:`model_cls` is a string that does not
                contain exactly one ``:`` separator.
        """
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        else:
            model = _RegisteredModel.from_model_cls(model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: List[str]):
        """
        Always raise :class:`ValueError` for :code:`architectures`.

        The message distinguishes architectures that are registered but
        failed to be inspected or loaded from ones that are not registered.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[Type[nn.Module]]:
        """Return the class for :code:`model_arch`, or None if unavailable."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Return the info for :code:`model_arch`, or None if unavailable."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, List[str]],
    ) -> List[str]:
        """Wrap a single name in a list; warn when nothing was specified."""
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        return architectures

    def inspect_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[_ModelInfo, str]:
        """
        Return the info of the first architecture that can be inspected,
        together with that architecture's name.

        Raises:
            ValueError: If none of :code:`architectures` can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[Type[nn.Module], str]:
        """
        Return the class of the first architecture that can be loaded,
        together with that architecture's name.

        Raises:
            ValueError: If none of :code:`architectures` can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the ``is_text_generation_model`` flag of the resolved info."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the ``is_pooling_model`` flag of the resolved info."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the ``supports_cross_encoding`` flag of the resolved info."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the ``supports_multimodal`` flag of the resolved info."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the ``supports_pp`` flag of the resolved info."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the ``has_inner_state`` flag of the resolved info."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the ``is_attention_free`` flag of the resolved info."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Return the ``is_hybrid`` flag of the resolved info."""
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

500
501
502
503
504
505
506
507
508
509
510
511
512

# Registry instance pre-populated with one lazy entry per built-in
# architecture, so no model module is imported at this point.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{module_relname}",
            class_name=class_name,
        )
        for arch_name, (module_relname, class_name) in _VLLM_MODELS.items()
    })

# Return type of the callable executed by `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Execute :code:`fn` in a fresh Python process and return its result.

    The callable and an output path are sent to the child over stdin; the
    child (this module run with ``-m``) writes the pickled result to that
    path, which is read back here.

    Raises:
        RuntimeError: If the child exits with a non-zero status. The child's
            stderr is included in the message.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            [sys.executable, "-m", "vllm.model_executor.models.registry"],
            input=input_bytes,
            capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently so that undecodable bytes in the child's
            # stderr cannot mask the real failure with a UnicodeDecodeError.
            stderr_text = returned.stderr.decode(errors="replace")
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr_text}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """
    Entry point of the child process started by :code:`_run_in_subprocess`.

    Reads a pickled :code:`(fn, output_file)` pair from stdin, calls
    :code:`fn`, and writes the pickled result to :code:`output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE(review): unpickling executes arbitrary code. This is acceptable
    # only because stdin is expected to carry the payload produced by
    # `_run_in_subprocess` in the parent process; never feed this entry
    # point untrusted input.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)
551
552
553


# Executed when this module is run with `python -m`, which is how
# `_run_in_subprocess` starts the child process.
if __name__ == "__main__":
    _run()