# vllm/model_executor/models/registry.py
"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
5
import importlib
import os
import pickle
import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import lru_cache
from typing import (AbstractSet, Callable, Dict, List, NoReturn, Optional,
                    Tuple, Type, TypeVar, Union)

import cloudpickle
import torch.nn as nn

from vllm.logger import init_logger

from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
                         supports_cross_encoding, supports_multimodal,
                         supports_pp)
from .interfaces_base import is_text_generation_model

# Module-level logger; used for re-registration warnings and for the
# tracebacks of failed model loads/inspections.
logger = init_logger(__name__)

29
# yapf: disable
30
31
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
32
33
34
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
35
36
37
38
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
39
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
40
    # ChatGLMModel supports multimodal
41
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
42
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
43
44
45
46
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
Simon Mo's avatar
Simon Mo committed
47
    "DeepseekV3ForCausalLM": ("deepseek_v3", "DeepseekV3ForCausalLM"),
48
49
50
51
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
52
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
53
54
55
56
57
58
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
59
    "GritLM": ("gritlm", "GritLM"),
60
61
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
62
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
63
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
64
65
66
67
68
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
69
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
70
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
71
72
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
73
74
75
76
77
78
79
80
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
81
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
82
83
84
85
86
87
88
89
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
90
    # QWenLMHeadModel supports multimodal
91
92
93
94
95
96
97
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
98
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
99
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
100
101
102
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
103
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
104
105
106
}

_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    # [Auto-converted (see adapters.py)]
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
}

138
139
140
141
142
143
144
145
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
}

146
_MULTIMODAL_MODELS = {
147
    # [Decoder-only]
148
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
149
150
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
151
152
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
153
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
154
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
155
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
156
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
157
    "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
158
159
160
161
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
162
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
163
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
164
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
165
166
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
167
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
168
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
169
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
170
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
171
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
172
    "UltravoxModel": ("ultravox", "UltravoxModel"),
173
174
    # [Encoder-decoder]
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
175
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
176
}
177
178
179
180
181

_SPECULATIVE_DECODING_MODELS = {
    "EAGLEModel": ("eagle", "EAGLE"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
182
}
183
# yapf: enable
184

# Union of all built-in tables. With dict unpacking the last occurrence of a
# key wins, so an architecture listed in several tables takes its entry from
# the later one (the duplicated entries currently agree).
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
}


194
195
@dataclass(frozen=True)
class _ModelInfo:
196
    architecture: str
197
    is_text_generation_model: bool
198
    is_pooling_model: bool
199
    supports_cross_encoding: bool
200
201
    supports_multimodal: bool
    supports_pp: bool
202
203
    has_inner_state: bool
    is_attention_free: bool
204
    is_hybrid: bool
205
206

    @staticmethod
207
208
    def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
        return _ModelInfo(
209
            architecture=model.__name__,
210
            is_text_generation_model=is_text_generation_model(model),
211
            is_pooling_model=True,  # Can convert any model into a pooling model
212
            supports_cross_encoding=supports_cross_encoding(model),
213
214
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
215
216
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
217
            is_hybrid=is_hybrid(model),
218
        )
219
220


class _BaseRegisteredModel(ABC):
    """Interface implemented by every entry stored in the model registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> Type[nn.Module]:
        """Return the model class, importing it first if necessary."""
        raise NotImplementedError


@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability flags supplied at construction (computed eagerly by
    # `from_model_cls`).
    interfaces: _ModelInfo
    model_cls: Type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: Type[nn.Module]) -> "_RegisteredModel":
        """Wrap an imported class, inspecting its interfaces in this process."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        # No work needed: the flags were stored when the entry was created.
        return self.interfaces

    def load_model_cls(self) -> Type[nn.Module]:
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    The class is located by ``module_name`` / ``class_name`` and is only
    imported when it is loaded or inspected.
    """
    module_name: str
    class_name: str

    def inspect_model_cls(self) -> _ModelInfo:
        # Import and inspect inside a separate interpreter so that importing
        # the model module cannot initialize CUDA in this process.
        def _inspect_in_child():
            return _ModelInfo.from_model_cls(self.load_model_cls())

        return _run_in_subprocess(_inspect_in_child)

    def load_model_cls(self) -> Type[nn.Module]:
        module = importlib.import_module(self.module_name)
        resolved_cls = getattr(module, self.class_name)
        return resolved_cls


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
    """Load the model class registered for *model_arch*.

    Any exception raised by ``model.load_model_cls()`` is logged with its
    traceback and converted to ``None``. The platform check runs outside the
    ``try``, so exceptions it raises propagate to the caller. Results
    (including ``None``) are memoized per ``(model_arch, model)`` pair.
    """
    # Imported here rather than at module level; presumably to avoid an
    # import cycle or eager platform detection -- TODO confirm.
    from vllm.platforms import current_platform
    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None


@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Inspect the model registered for *model_arch*.

    Any exception raised during inspection is logged with its traceback and
    reported as ``None``. Results (including ``None``) are memoized per
    ``(model_arch, model)`` pair.
    """
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        info = None
    return info


@dataclass
class _ModelRegistry:
    """Mapping from model architecture names to registered model entries.

    Each entry is a :class:`_BaseRegisteredModel`: either an already
    imported class (:class:`_RegisteredModel`) or a lazily imported
    ``<module>:<class>`` reference (:class:`_LazyRegisteredModel`).
    """

    # Keyed by model_arch
    models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> AbstractSet[str]:
        """Return a live (uncopied) view of the registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[Type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        :code:`model_cls` can be either:

        - A :class:`torch.nn.Module` class directly referencing the model.
        - A string in the format :code:`<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        An existing registration for :code:`model_arch` is replaced, and a
        warning is logged.

        :raises ValueError: if :code:`model_cls` is a string that is not of
            the form :code:`<module>:<class>` with both parts non-empty.
        """
        # Build (and validate) the new entry before touching the registry, so
        # that a rejected call neither logs a misleading "will be overwritten"
        # warning nor modifies `self.models`.
        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2 or not all(split_str):
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model: _BaseRegisteredModel = _LazyRegisteredModel(*split_str)
        else:
            model = _RegisteredModel.from_model_cls(model_cls)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: List[str]) -> NoReturn:
        """Raise a ``ValueError`` explaining why none of *architectures* resolved.

        Distinguishes architectures that are registered but failed to load or
        inspect (the ``_try_*`` helpers already logged the error) from
        architectures that are not registered at all. Never returns.
        """
        all_supported_archs = self.get_supported_archs()

        # A registered arch only ends up here if loading/inspecting it failed.
        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[Type[nn.Module]]:
        """Load the class for *model_arch*; ``None`` if unregistered or if
        loading failed (the failure is logged)."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect *model_arch*; ``None`` if unregistered or if inspection
        failed (the failure is logged)."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, List[str]],
    ) -> List[str]:
        """Wrap a single architecture name in a list; warn if none is given."""
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        return architectures

    def inspect_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[_ModelInfo, str]:
        """Return ``(info, arch)`` for the first inspectable architecture.

        Architectures are tried in the given order.

        :raises ValueError: if none of *architectures* can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[Type[nn.Module], str]:
        """Return ``(model_cls, arch)`` for the first loadable architecture.

        Architectures are tried in the given order.

        :raises ValueError: if none of *architectures* can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    # The predicates below read one flag from the `_ModelInfo` of the first
    # inspectable architecture; they raise ValueError like
    # `inspect_model_cls` when none can be inspected.

    def is_text_generation_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.supports_pp

    def model_has_inner_state(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.has_inner_state

    def is_attention_free_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_attention_free

    def is_hybrid_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        model_info, _ = self.inspect_model_cls(architectures)
        return model_info.is_hybrid

# Module-level registry, pre-populated with every built-in architecture as a
# lazy entry, so no model module is imported until it is actually needed.
ModelRegistry = _ModelRegistry({
    arch_name: _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{relative_module}",
        class_name=target_cls_name,
    )
    for arch_name, (relative_module, target_cls_name) in _VLLM_MODELS.items()
})

_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Call *fn* in a fresh Python interpreter and return its result.

    *fn* and an output path are serialized with ``cloudpickle`` and sent on
    stdin to ``python -m vllm.model_executor.models.registry``. The child
    pickles ``fn()``'s result to that path, and this process reads it back.

    :raises RuntimeError: if the child exits with a non-zero status; the
        message contains the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            [sys.executable, "-m", "vllm.model_executor.models.registry"],
            input=input_bytes,
            capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the error to surface the child's stderr. Decode leniently
            # so a stray non-UTF-8 byte cannot raise UnicodeDecodeError here
            # and hide the real failure.
            stderr = returned.stderr.decode(errors="replace")
            raise RuntimeError(
                f"Error raised in subprocess:\n{stderr}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Child-process entry point used by ``_run_in_subprocess``.

    Reads a pickled ``(fn, output_file)`` pair from stdin, calls ``fn()``,
    and pickles the result into ``output_file``.
    """
    # Set up plugins before running `fn`
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: unpickling runs arbitrary code; stdin is expected to come only
    # from `_run_in_subprocess` in the parent process.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)


if __name__ == "__main__":
    # Reached when `_run_in_subprocess` launches
    # `python -m vllm.model_executor.models.registry`.
    _run()