"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
import importlib
import os
import pickle
import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import lru_cache
from typing import (AbstractSet, Callable, Dict, List, NoReturn, Optional,
                    Tuple, Type, TypeVar, Union)

import cloudpickle
import torch.nn as nn

from vllm.logger import init_logger
from vllm.platforms import current_platform

from .interfaces import (has_inner_state, is_attention_free,
                         supports_cross_encoding, supports_multimodal,
                         supports_pp)
from .interfaces_base import is_embedding_model, is_text_generation_model

logger = init_logger(__name__)

# The registry tables below are hand-formatted (one entry per line), so keep
# yapf from reflowing them until the matching `# yapf: enable`.
# yapf: disable
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
33
34
35
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
36
37
38
39
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
40
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
41
    # ChatGLMModel supports multimodal
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
59
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
60
61
62
63
64
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
65
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
66
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
67
68
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
69
70
71
72
73
74
75
76
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
77
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
78
79
80
81
82
83
84
85
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
86
    # QWenLMHeadModel supports multimodal
87
88
89
90
91
92
93
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
94
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
95
    "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
96
97
98
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
99
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
100
101
102
}

_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
    "LlamaModel": ("llama", "LlamaEmbeddingModel"),
    **{
        # Multiple models share the same architecture, so we include them all:
        # every text-generation entry implemented by LlamaForCausalLM is also
        # registered here.
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaEmbeddingModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"),  # noqa: E501
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
}

_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
}

137
_MULTIMODAL_MODELS = {
138
    # [Decoder-only]
139
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
140
141
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
142
143
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
144
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
145
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
146
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
147
    "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
148
149
150
151
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
152
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
153
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
154
155
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
156
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
157
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
158
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
159
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
160
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
161
    "UltravoxModel": ("ultravox", "UltravoxModel"),
162
163
    # [Encoder-decoder]
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
164
}
165
166
167
168
169

_SPECULATIVE_DECODING_MODELS = {
    "EAGLEModel": ("eagle", "EAGLE"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
170
}
171
# yapf: enable
172

173
# Union of all the tables above. With `**` unpacking, a later table wins on
# a duplicate key, so an architecture listed in several tables should map to
# the same (module, class) entry in each of them.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
}

# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
                    "Triton flash attention. For half-precision SWA support, "
                    "please use CK flash attention by setting "
                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    "Qwen2ForCausalLM":
    _ROCM_SWA_REASON,
    "MistralForCausalLM":
    _ROCM_SWA_REASON,
    "MixtralForCausalLM":
    _ROCM_SWA_REASON,
    "PaliGemmaForConditionalGeneration":
    ("ROCm flash attention does not yet "
     "fully support 32-bit precision on PaliGemma"),
    "Phi3VForCausalLM":
    ("ROCm Triton flash attention may run into compilation errors due to "
     "excessive use of shared memory. If this happens, disable Triton FA "
     "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}


207
208
209
210
@dataclass(frozen=True)
class _ModelInfo:
    is_text_generation_model: bool
    is_embedding_model: bool
211
    supports_cross_encoding: bool
212
213
    supports_multimodal: bool
    supports_pp: bool
214
215
    has_inner_state: bool
    is_attention_free: bool
216
217

    @staticmethod
218
219
220
221
    def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
        return _ModelInfo(
            is_text_generation_model=is_text_generation_model(model),
            is_embedding_model=is_embedding_model(model),
222
            supports_cross_encoding=supports_cross_encoding(model),
223
224
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
225
226
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
227
        )
228
229


230
class _BaseRegisteredModel(ABC):
    """Interface of an entry stored in :class:`_ModelRegistry`."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> Type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError


@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    Its capability flags are computed once, in-process, at registration time.
    """

    interfaces: _ModelInfo
    model_cls: Type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: Type[nn.Module]) -> "_RegisteredModel":
        """Inspect ``model_cls`` and wrap it together with its flags."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        return self.interfaces

    def load_model_cls(self) -> Type[nn.Module]:
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    Only the import path (module + class name) is stored; the module is
    imported when the class is actually requested.
    """
    module_name: str
    class_name: str

    def inspect_model_cls(self) -> _ModelInfo:
        # Import and inspect in a separate Python process so that importing
        # the model module does not initialize CUDA in this one.
        def _inspect_in_child() -> _ModelInfo:
            return _ModelInfo.from_model_cls(self.load_model_cls())

        return _run_in_subprocess(_inspect_in_child)

    def load_model_cls(self) -> Type[nn.Module]:
        module = importlib.import_module(self.module_name)
        model_cls = getattr(module, self.class_name)
        return model_cls


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
    """Load the model class registered under ``model_arch``.

    On ROCm, raises ``ValueError`` for architectures listed in
    ``_ROCM_UNSUPPORTED_MODELS`` and logs a warning for those listed in
    ``_ROCM_PARTIALLY_SUPPORTED_MODELS``.

    Returns ``None`` (after logging the traceback) if loading fails, so the
    caller can move on to the next candidate architecture.
    """
    if current_platform.is_rocm():
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(f"Model architecture '{model_arch}' is not "
                             "supported by ROCm for now.")

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially "
                "supported by ROCm: %s", model_arch, msg)

    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None


@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Return the capability flags of ``model``, or ``None`` on failure.

    Any ``Exception`` raised while inspecting is logged (with traceback)
    under ``model_arch`` instead of being propagated. Results, including
    ``None``, are memoized per ``(model_arch, model)`` pair.
    """
    model_info: Optional[_ModelInfo] = None
    try:
        model_info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
    return model_info


@dataclass
class _ModelRegistry:
    """Mapping from architecture name to a registered model, with helpers to
    load model classes and query their capability flags.

    Query methods accept either a single architecture name or a list of
    candidates; the first candidate that can be loaded/inspected wins.
    """

    # Keyed by model_arch
    models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> AbstractSet[str]:
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[Type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        :code:`model_cls` can be either:

        - A :class:`torch.nn.Module` class directly referencing the model.
        - A string in the format :code:`<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture name that already exists overwrites the
        previous entry (a warning is logged).

        Raises:
            ValueError: If ``model_cls`` is a string that is not of the form
                ``<module>:<class>`` with both parts non-empty.
        """
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        model: _BaseRegisteredModel
        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            # Also reject an empty module or class part (e.g. ":Cls" or
            # "pkg.mod:") here rather than failing later at import time.
            if len(split_str) != 2 or not all(split_str):
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        else:
            model = _RegisteredModel.from_model_cls(model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: List[str]) -> NoReturn:
        """Always raise ``ValueError`` explaining why none of
        ``architectures`` could be used.

        The message distinguishes registered architectures that failed to
        load/inspect (details were logged) from unregistered ones.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[Type[nn.Module]]:
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, List[str]],
    ) -> List[str]:
        """Wrap a single architecture name in a list; warn if none given."""
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        return architectures

    def inspect_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> _ModelInfo:
        """Return the flags of the first architecture that can be inspected.

        Raises:
            ValueError: If no candidate architecture can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return model_info

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[Type[nn.Module], str]:
        """Return ``(model_cls, arch)`` for the first architecture whose
        model class can be loaded.

        Raises:
            ValueError: If no candidate architecture can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        return self.inspect_model_cls(architectures).is_text_generation_model

    def is_embedding_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        return self.inspect_model_cls(architectures).is_embedding_model

    def is_cross_encoder_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        return self.inspect_model_cls(architectures).supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        return self.inspect_model_cls(architectures).supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        return self.inspect_model_cls(architectures).supports_pp

    def model_has_inner_state(self, architectures: Union[str,
                                                         List[str]]) -> bool:
        return self.inspect_model_cls(architectures).has_inner_state

    def is_attention_free_model(self, architectures: Union[str,
                                                           List[str]]) -> bool:
        return self.inspect_model_cls(architectures).is_attention_free


# Global registry: every built-in architecture is registered lazily, with its
# module resolved under the `vllm.model_executor.models` package.
ModelRegistry = _ModelRegistry(
    models={
        arch_name: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{module_relname}",
            class_name=class_name,
        )
        for arch_name, (module_relname, class_name) in _VLLM_MODELS.items()
    })

# Result type of the callable passed to `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Call ``fn`` in a fresh Python interpreter and return its result.

    ``fn`` and an output path are serialized with ``cloudpickle`` and sent
    over stdin to this module run with ``python -m``. The child pickles the
    result into that path, inside a temporary directory, and it is read back
    here.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero return code;
            the child's stderr is included in the message.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            [sys.executable, "-m", "vllm.model_executor.models.registry"],
            input=input_bytes,
            capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the error to surface the child's stderr. Decode leniently
            # so undecodable output cannot mask the real failure with a
            # UnicodeDecodeError.
            stderr = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr}") from e

        # The file was written by our own child process above.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """Child-process entry point used by `_run_in_subprocess`.

    Reads a pickled ``(fn, output_file)`` pair from stdin, calls ``fn`` and
    pickles its return value into ``output_file``.
    """
    # Load general plugins in the child as well (presumably so that models
    # registered by plugins can be imported/inspected here too).
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: stdin is expected to come only from `_run_in_subprocess` in the
    # parent process; never feed untrusted data to this entry point.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)


# Invoked as `python -m vllm.model_executor.models.registry` by
# `_run_in_subprocess`.
if __name__ == "__main__":
    _run()