# registry.py
1
import importlib
2
import os
3
import pickle
4
5
import subprocess
import sys
6
import tempfile
7
8
9
10
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
11

12
import cloudpickle
13
14
15
import torch.nn as nn

from vllm.logger import init_logger
16
from vllm.platforms import current_platform
17

18
19
from .interfaces import (has_inner_state, is_attention_free,
                         supports_multimodal, supports_pp)
20
from .interfaces_base import is_embedding_model, is_text_generation_model
21
22
23

# Module-level logger, created through vLLM's `init_logger` helper.
logger = init_logger(__name__)

24
# yapf: disable
25
26
# Text-generation architectures.
# Maps architecture name -> (module name, class name); the module name is
# relative to `vllm.model_executor.models`.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    # ChatGLMModel supports multimodal
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    # QWenLMHeadModel supports multimodal
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
}

# Embedding (pooling) architectures.
# Maps architecture name -> (module name, class name); the module name is
# relative to `vllm.model_executor.models`.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
    "LlamaModel": ("llama", "LlamaEmbeddingModel"),
    "MistralModel": ("llama", "LlamaEmbeddingModel"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForSequenceClassification": (
        "qwen2_cls", "Qwen2ForSequenceClassification"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
}

111
112
113
114
115
116
117
118
119
120
121
122
123
def add_embedding_models(base_models, embedding_models):
    """Pick the entries of ``base_models`` implemented by an embedding class.

    ``base_models`` maps an architecture name to a ``(path, arch)`` pair. An
    entry is kept when its ``arch`` component is one of the keys of
    ``embedding_models``. The result is a new dict; neither input is
    modified.
    """
    known_embedding_archs = embedding_models.keys()
    return {
        alias: (module_path, impl_arch)
        for alias, (module_path, impl_arch) in base_models.items()
        if impl_arch in known_embedding_archs
    }

# Extend the embedding table with every text-generation alias whose
# implementing class name is itself a key of the embedding table. The
# explicit entries are unpacked last, so they win on duplicate keys.
_EMBEDDING_MODELS = {
    **add_embedding_models(_TEXT_GENERATION_MODELS, _EMBEDDING_MODELS),
    **_EMBEDDING_MODELS,
}

124
# Multimodal architectures.
# Maps architecture name -> (module name, class name); the module name is
# relative to `vllm.model_executor.models`.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    # [Encoder-decoder]
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
}
149
150
151
152
153

# Speculative-decoding architectures.
# Maps architecture name -> (module name, class name); the module name is
# relative to `vllm.model_executor.models`.
_SPECULATIVE_DECODING_MODELS = {
    "EAGLEModel": ("eagle", "EAGLE"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
155
# yapf: enable
156

157
# Union of every built-in architecture table. Tables are unpacked in order,
# so on a duplicate architecture name the entry from the later table wins.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
}

# Architectures that cannot be used on ROCm at all.
_ROCM_UNSUPPORTED_MODELS: List[str] = []

# Explanation shared by every architecture that relies on sliding window
# attention.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
                    "Triton flash attention. For half-precision SWA support, "
                    "please use CK flash attention by setting "
                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")

# Architectures that work on ROCm with caveats.
# Architecture -> reason reported to the user.
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    "Qwen2ForCausalLM": _ROCM_SWA_REASON,
    "MistralForCausalLM": _ROCM_SWA_REASON,
    "MixtralForCausalLM": _ROCM_SWA_REASON,
    "PaliGemmaForConditionalGeneration": (
        "ROCm flash attention does not yet "
        "fully support 32-bit precision on PaliGemma"),
    "Phi3VForCausalLM": (
        "ROCm Triton flash attention may run into compilation errors due to "
        "excessive use of shared memory. If this happens, disable Triton FA "
        "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`"),
}


190
191
192
193
194
195
@dataclass(frozen=True)
class _ModelInfo:
    """
    Immutable record of the interface checks performed on a model class.

    Each field holds the result of the same-named predicate imported from
    `.interfaces` / `.interfaces_base`.
    """
    is_text_generation_model: bool
    is_embedding_model: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool

    @staticmethod
    def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
        """Build a `_ModelInfo` by running every interface check on `model`."""
        return _ModelInfo(
            is_text_generation_model=is_text_generation_model(model),
            is_embedding_model=is_embedding_model(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
        )
209
210


211
class _BaseRegisteredModel(ABC):
    """
    Interface of a registry entry: it can describe its model class and
    hand the class itself back.
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the interface information of the model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> Type[nn.Module]:
        """Return the model class."""
        raise NotImplementedError
220
221


222
223
224
225
226
227
228
229
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Interface information computed once, when the entry is created.
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: Type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: Type[nn.Module]) -> "_RegisteredModel":
        """Create an entry for `model_cls`, inspecting it immediately."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the interface information stored at construction."""
        return self.interfaces

    def load_model_cls(self) -> Type[nn.Module]:
        """Return the stored model class."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    # Dotted path of the module that defines the model class.
    module_name: str
    # Attribute name of the model class inside that module.
    class_name: str

    def inspect_model_cls(self) -> _ModelInfo:
        """
        Inspect the model class in another process, which avoids
        initializing CUDA in the current one.
        """
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> Type[nn.Module]:
        """Import the module and return the class it defines."""
        return getattr(importlib.import_module(self.module_name),
                       self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
    """
    Load the model class of `model`, returning `None` if loading fails.

    On ROCm, an architecture listed in `_ROCM_UNSUPPORTED_MODELS` raises
    `ValueError` (this happens before the guarded load, so it propagates to
    the caller), and one listed in `_ROCM_PARTIALLY_SUPPORTED_MODELS` only
    logs a warning.

    Results are memoized by `lru_cache`, including a `None` result, so a
    failed load is not retried for the same `(model_arch, model)` pair while
    it stays in the cache.
    """
    if current_platform.is_rocm():
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(f"Model architecture '{model_arch}' is not "
                             "supported by ROCm for now.")

        # Single lookup instead of a membership test followed by indexing.
        msg = _ROCM_PARTIALLY_SUPPORTED_MODELS.get(model_arch)
        if msg is not None:
            logger.warning(
                "Model architecture '%s' is partially "
                "supported by ROCm: %s", model_arch, msg)

    try:
        return model.load_model_cls()
    except Exception:
        # Best effort: log the traceback and let the caller try the next
        # candidate architecture.
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
285
286


287
288
289
290
291
292
293
294
295
296
297
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """
    Inspect the model class of `model`, returning `None` if that fails.

    Any exception raised by the inspection is logged with its traceback
    rather than propagated. Results (including `None`) are memoized by
    `lru_cache`.
    """
    try:
        model_info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        model_info = None
    return model_info
298
299


300
301
302
303
@dataclass
class _ModelRegistry:
    """
    Maps model architecture names to registered model implementations.

    The query methods accept either one architecture name or a list of
    candidates, and use the first candidate that can be loaded or inspected.
    """

    # Keyed by model_arch
    models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> List[str]:
        """Return the names of all registered architectures."""
        return list(self.models.keys())

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[Type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        :code:`model_cls` can be either:

        - A :class:`torch.nn.Module` class directly referencing the model.
        - A string in the format :code:`<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that already exists logs a warning and
        replaces the previous entry. A string that does not split into
        exactly two parts on ``:`` raises :class:`ValueError`.
        """
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        model: _BaseRegisteredModel
        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        else:
            model = _RegisteredModel.from_model_cls(model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: List[str]):
        """
        Raise `ValueError` explaining why none of `architectures` was usable.

        This method never returns normally.
        """
        all_supported_archs = self.get_supported_archs()

        # A registered architecture only reaches this point when loading or
        # inspecting it failed; calling it "not supported" would send the
        # user looking in the wrong place.
        if any(arch in self.models for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed to be loaded "
                "or inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[Type[nn.Module]]:
        """Load the class for `model_arch`; `None` if unknown or failing."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect the class for `model_arch`; `None` if unknown or failing."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, List[str]],
    ) -> List[str]:
        """Wrap a single name in a list; warn when nothing was specified."""
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        return architectures

    def inspect_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> _ModelInfo:
        """
        Return the `_ModelInfo` of the first candidate that can be inspected.

        Raises `ValueError` when no candidate yields a result.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return model_info

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[Type[nn.Module], str]:
        """
        Return `(model class, architecture name)` for the first candidate
        that can be loaded.

        Raises `ValueError` when no candidate yields a class.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the resolved model is a text generation model."""
        return self.inspect_model_cls(architectures).is_text_generation_model

    def is_embedding_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the resolved model is an embedding model."""
        return self.inspect_model_cls(architectures).is_embedding_model

    def is_multimodal_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the resolved model supports multimodal inputs."""
        return self.inspect_model_cls(architectures).supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        """Whether the resolved model reports `supports_pp`."""
        return self.inspect_model_cls(architectures).supports_pp

    def model_has_inner_state(self, architectures: Union[str,
                                                         List[str]]) -> bool:
        """Whether the resolved model reports `has_inner_state`."""
        return self.inspect_model_cls(architectures).has_inner_state

    def is_attention_free_model(self, architectures: Union[str,
                                                           List[str]]) -> bool:
        """Whether the resolved model reports `is_attention_free`."""
        return self.inspect_model_cls(architectures).is_attention_free

431
432
433
434
435
436
437
438
439
440
441
442
443

# Module-level registry instance, pre-populated with one lazy entry per
# architecture in `_VLLM_MODELS`. Module names are resolved under
# `vllm.model_executor.models`; a lazy entry imports its module only when
# the model class is loaded.
ModelRegistry = _ModelRegistry({
    model_arch: _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})

# Return type of the callable handed to `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Call `fn` in a fresh Python process and return its result.

    `fn` and an output path are serialized with `cloudpickle` and sent to
    the child on stdin; the child runs this module with `python -m`, and its
    result is read back from the output file with `pickle`.

    Raises:
        RuntimeError: If the child exits with a non-zero status. The message
            contains the child's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            [sys.executable, "-m", "vllm.model_executor.models.registry"],
            input=input_bytes,
            capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Decode leniently: undecodable bytes in stderr must not raise
            # and hide the failure being reported.
            stderr_text = returned.stderr.decode(errors="replace")
            # wrap raised exception to provide more information
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr_text}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """
    Child-process side of `_run_in_subprocess`.

    Reads a pickled `(fn, output_file)` pair from stdin, calls `fn`, and
    writes the pickled result to `output_file`.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: unpickling can execute arbitrary code, so stdin must only ever
    # carry a payload produced by a trusted parent process.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)
482
483
484
485


# Script entry point: `_run_in_subprocess` launches this module with
# `python -m`, which is when `_run` executes.
if __name__ == "__main__":
    _run()