# SPDX-License-Identifier: Apache-2.0
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
import importlib
import os
import pickle
import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import lru_cache
from typing import (AbstractSet, Callable, Dict, List, Optional, Tuple, Type,
                    TypeVar, Union)

import cloudpickle
import torch.nn as nn

from vllm.logger import init_logger

from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
                         supports_cross_encoding, supports_multimodal,
                         supports_pp, supports_transcription)
from .interfaces_base import is_text_generation_model

# Module-level logger; used by the registry to warn about overwritten or
# missing architectures and to log model load/inspection failures.
logger = init_logger(__name__)

# yapf: disable
# Maps an architecture name to ``(module name under
# vllm.model_executor.models, class name)`` for text generation models.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
}

# Maps an architecture name to ``(module name, class name)`` for models that
# can be served as embedding / pooling models.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # [Auto-converted (see adapters.py)]
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}

# Cross-encoder (sequence classification) architectures. Both RoBERTa
# variants are served by the same vLLM implementation.
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    **{
        hf_arch: ("roberta", "RobertaForSequenceClassification")
        for hf_arch in ("RobertaForSequenceClassification",
                        "XLMRobertaForSequenceClassification")
    },
}

# Maps an architecture name to ``(module name, class name)`` for multimodal
# models.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    # [Encoder-decoder]
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}

# Draft / speculator architectures used for speculative decoding.
_SPECULATIVE_DECODING_MODELS = {
    "EAGLEModel": ("eagle", "EAGLE"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

# Generic implementation that `_ModelRegistry._normalize_archs` substitutes
# for architecture names that are not registered.
_FALLBACK_MODEL = {
    "TransformersModel": (
        "transformers",
        "TransformersModel",
    ),
}
# yapf: enable

# Every built-in architecture, merged from the tables above. With dict
# unpacking, a later table takes precedence for a duplicated key.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_FALLBACK_MODEL,
}

# Argument list passed to `subprocess.run()` by `_run_in_subprocess`. It is
# kept in a module-level variable so the args can be altered when needed,
# e.g. when things are packed together with the par format, sys.executable
# might not be the target we want to run.
_SUBPROCESS_COMMAND = [
    sys.executable,
    "-m",
    "vllm.model_executor.models.registry",
]

@dataclass(frozen=True)
class _ModelInfo:
    """Capabilities of a model class, as probed by the interface helpers."""

    # Name of the model class (``model.__name__``).
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    supports_transcription: bool

    @staticmethod
    def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
        """Build a :class:`_ModelInfo` for ``model``.

        Each flag comes from the matching predicate imported from
        ``.interfaces`` / ``.interfaces_base``. The exception is
        ``is_pooling_model``, which is always ``True``.
        """
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=True,  # Can convert any model into a pooling model
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model))


class _BaseRegisteredModel(ABC):
    """Abstract entry of the model registry for a single architecture."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the :class:`_ModelInfo` describing the model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> Type[nn.Module]:
        """Return the model class itself."""
        raise NotImplementedError


@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capabilities of ``model_cls`` (computed by :meth:`from_model_cls`).
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: Type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: Type[nn.Module]) -> "_RegisteredModel":
        """Wrap ``model_cls`` and compute its :class:`_ModelInfo` eagerly,
        in the current process."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the stored interface information."""
        return self.interfaces

    def load_model_cls(self) -> Type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    The class is found on demand by importing :attr:`module_name` and
    reading the attribute :attr:`class_name` from it.
    """
    # Importable module path that defines the model class.
    module_name: str
    # Attribute name of the model class inside ``module_name``.
    class_name: str

    def inspect_model_cls(self) -> _ModelInfo:
        # Import + inspection are performed in another process to avoid
        # initializing CUDA in this one.
        def _inspect():
            return _ModelInfo.from_model_cls(self.load_model_cls())

        return _run_in_subprocess(_inspect)

    def load_model_cls(self) -> Type[nn.Module]:
        module = importlib.import_module(self.module_name)
        model_cls = getattr(module, self.class_name)
        return model_cls


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
    """Load the model class for ``model_arch``, or return ``None`` on error.

    The current platform verifies ``model_arch`` first; any exception from
    that check propagates to the caller. Exceptions raised while loading the
    class are logged and turned into ``None``. Results (including ``None``)
    are memoized per ``(model_arch, model)`` pair.
    """
    # Imported inside the function (presumably to avoid an import cycle at
    # module import time -- TODO confirm).
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)

    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None


@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Return ``model``'s :class:`_ModelInfo`, or ``None`` if inspection fails.

    Any exception is logged together with ``model_arch``. The outcome is
    memoized per ``(model_arch, model)`` pair.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        info = None
    return info


@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to vLLM model implementations.

    Names that are not registered are resolved to the generic
    ``"TransformersModel"`` entry by :meth:`_normalize_archs`.
    """

    # Keyed by model_arch
    models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> AbstractSet[str]:
        """Return a live view of the registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[Type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        :code:`model_cls` can be either:

        - A :class:`torch.nn.Module` class directly referencing the model.
        - A string in the format :code:`<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Re-registering an existing architecture logs a warning and replaces
        the previous entry.

        Raises:
            ValueError: if ``model_cls`` is a string that does not split into
                exactly two parts on ``:``.
        """
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        else:
            model = _RegisteredModel.from_model_cls(model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: List[str]):
        """Always raise a ``ValueError`` explaining why no arch was usable.

        If any of ``architectures`` is registered, it failed to load or to be
        inspected (the cause was logged). Otherwise, none of them is
        supported.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[Type[nn.Module]]:
        """Load one architecture's class; ``None`` if unknown or failed."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect one architecture; ``None`` if unknown or failed."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, List[str]],
    ) -> List[str]:
        """Return the architecture names to try, in priority order.

        Registered architectures keep their relative order. If any requested
        architecture is not registered, one ``"TransformersModel"`` fallback
        entry is appended *after* all of them. This way a natively supported
        architecture is always tried before the generic fallback, regardless
        of its position in ``architectures``.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        normalized_arch = [
            arch for arch in architectures if arch in self.models
        ]
        # All unknown architectures map to the same fallback: add it once,
        # and only after every natively supported candidate.
        if len(normalized_arch) != len(architectures):
            normalized_arch.append("TransformersModel")
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[_ModelInfo, str]:
        """Return ``(model_info, arch)`` for the first inspectable arch.

        Raises:
            ValueError: if none of the normalized architectures can be
                inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[Type[nn.Module], str]:
        """Return ``(model_cls, arch)`` for the first loadable arch.

        Raises:
            ValueError: if none of the normalized architectures can be
                loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

    def is_transcription_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_transcription

# Global registry. Every built-in architecture is registered lazily, so its
# module is only imported when the class is loaded or inspected.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Call ``fn`` in a fresh Python process and return its result.

    ``fn`` and the path of a temporary output file are serialized with
    ``cloudpickle`` and piped to ``_SUBPROCESS_COMMAND``. That process (see
    ``_run``) calls ``fn`` and pickles the return value into the file, which
    is then loaded here.

    Raises:
        RuntimeError: if the subprocess exits with a non-zero status. The
            message contains the subprocess's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(_SUBPROCESS_COMMAND,
                                  input=input_bytes,
                                  capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the raised exception to provide more information. Decode
            # leniently so undecodable bytes in stderr cannot replace this
            # error with a UnicodeDecodeError.
            stderr = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr}") from e

        # The file was written by our own child process (see `_run`).
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Subprocess side of ``_run_in_subprocess``.

    Reads a pickled ``(fn, output_file)`` pair from stdin, calls ``fn()`` and
    pickles the result into ``output_file``.
    """
    # Setup plugins (presumably so plugin-registered models are available to
    # `fn` in this fresh interpreter -- TODO confirm)
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: stdin is trusted: it is produced by `_run_in_subprocess` in the
    # parent process. Never feed untrusted data to this entry point.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)


if __name__ == "__main__":
    # Entry point when executed via `_SUBPROCESS_COMMAND`
    # (`python -m vllm.model_executor.models.registry`).
    _run()