"""
Whenever you add an architecture to this file, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""

import importlib
6
import os
7
import pickle
8
9
import subprocess
import sys
10
import tempfile
11
12
13
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import lru_cache
14
15
from typing import (AbstractSet, Callable, Dict, List, Optional, Tuple, Type,
                    TypeVar, Union)
16

17
import cloudpickle
18
19
20
import torch.nn as nn

from vllm.logger import init_logger
21
from vllm.platforms import current_platform
22

23
24
from .interfaces import (has_inner_state, is_attention_free,
                         supports_multimodal, supports_pp)
25
from .interfaces_base import is_embedding_model, is_text_generation_model
26
27
28

# Module-level logger for registry warnings (ROCm caveats, overwritten
# registrations) and for errors raised while loading/inspecting models.
logger = init_logger(__name__)

29
# yapf: disable
30
31
# Architecture name -> (module under `vllm.model_executor.models`, class name)
# for text-generation models.
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    # baichuan-7b, upper case 'C' in the class name
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    # baichuan-13b, lower case 'c' in the class name
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    # ChatGLMModel is registered in _MULTIMODAL_MODELS (it supports multimodal)
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    # QWenLMHeadModel is registered in _MULTIMODAL_MODELS (it supports
    # multimodal)
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
}

# Architecture name -> (module under `vllm.model_executor.models`, class name)
# for embedding / reward / classification models.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
    "LlamaModel": ("llama", "LlamaEmbeddingModel"),
    **{
        # Multiple models share the same architecture, so we include them all:
        # every text-generation entry whose class is LlamaForCausalLM.
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaEmbeddingModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"),  # noqa: E501
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
}

# Architecture name -> (module under `vllm.model_executor.models`, class name)
# for models that accept multimodal inputs.
_MULTIMODAL_MODELS = {
    # [Decoder-only]
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration": ("idefics3", "Idefics3ForConditionalGeneration"),  # noqa: E501
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    # [Encoder-decoder]
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
}
151
152
153
154
155

# Architecture name -> (module under `vllm.model_executor.models`, class name)
# for speculative-decoding draft models.
_SPECULATIVE_DECODING_MODELS = {
    "EAGLEModel": ("eagle", "EAGLE"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
157
# yapf: enable
158

159
# All built-in architectures. On duplicate keys the later dict wins; the
# overlapping keys in the tables above (e.g. "Phi3ForCausalLM",
# "LlavaNextForConditionalGeneration") map to the same (module, class) in
# each table, so the merge order does not change what gets loaded.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
}

# Architectures that cannot run on ROCm at all. `_try_load_model_cls` raises
# a ValueError for any architecture listed here when running on ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = list()

# Shared explanation for architectures whose ROCm limitation is sliding
# window attention (SWA) under Triton flash attention.
_ROCM_SWA_REASON = (
    "Sliding window attention (SWA) is not yet supported in "
    "Triton flash attention. For half-precision SWA support, "
    "please use CK flash attention by setting "
    "`VLLM_USE_TRITON_FLASH_ATTN=0`"
)

# Architectures that load on ROCm but with caveats: architecture -> reason.
# `_try_load_model_cls` logs the reason as a warning when running on ROCm.
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    **dict.fromkeys(
        ("Qwen2ForCausalLM", "MistralForCausalLM", "MixtralForCausalLM"),
        _ROCM_SWA_REASON,
    ),
    "PaliGemmaForConditionalGeneration": (
        "ROCm flash attention does not yet "
        "fully support 32-bit precision on PaliGemma"
    ),
    "Phi3VForCausalLM": (
        "ROCm Triton flash attention may run into compilation errors due to "
        "excessive use of shared memory. If this happens, disable Triton FA "
        "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`"
    ),
}


192
193
194
195
196
197
@dataclass(frozen=True)
class _ModelInfo:
    """
    Capability flags of a model class.

    Each field holds the result of the same-named predicate imported from
    ``.interfaces`` / ``.interfaces_base``. The dataclass is frozen, so
    instances are immutable and hashable.
    """

    is_text_generation_model: bool
    is_embedding_model: bool
    supports_multimodal: bool
    # Whether the model supports pipeline parallelism (`supports_pp`).
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool

    @staticmethod
    def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
        """Compute every capability flag for the model class ``model``."""
        return _ModelInfo(
            is_text_generation_model=is_text_generation_model(model),
            is_embedding_model=is_embedding_model(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
        )
211
212


213
class _BaseRegisteredModel(ABC):
    """Abstract interface of an entry stored in the model registry."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags of the registered model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> Type[nn.Module]:
        """Return the registered model class itself."""
        raise NotImplementedError
222
223


224
225
226
227
228
229
230
231
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    Its capability flags are computed once, in the current process, when the
    entry is created through :meth:`from_model_cls`.
    """

    # Capability flags of `model_cls`, computed when the entry was created.
    interfaces: _ModelInfo
    # The model class itself.
    model_cls: Type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: Type[nn.Module]) -> "_RegisteredModel":
        """Create an entry for ``model_cls``, inspecting it eagerly."""
        return _RegisteredModel(
            interfaces=_ModelInfo.from_model_cls(model_cls),
            model_cls=model_cls,
        )

    def inspect_model_cls(self) -> _ModelInfo:
        """Return the flags computed when this entry was created."""
        return self.interfaces

    def load_model_cls(self) -> Type[nn.Module]:
        """Return the stored model class (no import needed)."""
        return self.model_cls


@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    A registry entry that only knows where the model class lives.

    ``module_name`` is not imported until :meth:`load_model_cls` runs, and
    :meth:`inspect_model_cls` performs that import in a child process.
    """
    module_name: str
    class_name: str

    def inspect_model_cls(self) -> _ModelInfo:
        # Import and inspect the model in a separate process so that doing
        # so cannot initialize CUDA in this one.
        def _inspect():
            return _ModelInfo.from_model_cls(self.load_model_cls())

        return _run_in_subprocess(_inspect)

    def load_model_cls(self) -> Type[nn.Module]:
        """Import ``module_name`` and return its ``class_name`` attribute."""
        return getattr(importlib.import_module(self.module_name),
                       self.class_name)


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
    """
    Load the model class registered as ``model``, or return ``None``.

    On ROCm, raises ``ValueError`` for architectures in
    ``_ROCM_UNSUPPORTED_MODELS`` and logs a warning for those in
    ``_ROCM_PARTIALLY_SUPPORTED_MODELS``.

    Any exception raised while loading the class is logged and turned into
    ``None``, so callers can fall back to another candidate architecture.
    Returned values (including ``None``) are memoized by ``lru_cache``, so a
    failed load is not retried while its entry stays in the cache.
    """
    if current_platform.is_rocm():
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(f"Model architecture '{model_arch}' is not "
                             "supported by ROCm for now.")

        reason = _ROCM_PARTIALLY_SUPPORTED_MODELS.get(model_arch)
        if reason is not None:
            logger.warning(
                "Model architecture '%s' is partially "
                "supported by ROCm: %s", model_arch, reason)

    try:
        return model.load_model_cls()
    except Exception:
        # Deliberately best-effort: log the full traceback and let the
        # caller try the next architecture.
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None
287
288


289
290
291
292
293
294
295
296
297
298
299
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """
    Return the capability flags reported by ``model``, or ``None`` if
    inspecting it raised.

    Failures are logged with their traceback under ``model_arch``. Results,
    including ``None``, are memoized by ``lru_cache``.
    """
    info = None
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
    return info
300
301


302
303
304
305
@dataclass
class _ModelRegistry:
    """
    Registry mapping model architecture names to model classes.

    Each entry is either a class imported up front (:class:`_RegisteredModel`)
    or a ``<module>:<class>`` reference imported on demand
    (:class:`_LazyRegisteredModel`).

    Query methods accept one architecture name or a list of names. Candidates
    are tried in order and the first that can be inspected (or loaded) is
    used. If none succeed, a ``ValueError`` is raised.
    """

    # Keyed by model_arch
    models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> AbstractSet[str]:
        """Return the registered architecture names (a live keys view)."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[Type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in vLLM.

        :code:`model_cls` can be either:

        - A :class:`torch.nn.Module` class directly referencing the model.
        - A string in the format :code:`<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        An existing registration for :code:`model_arch` is overwritten, and a
        warning is logged.

        :raises ValueError: if :code:`model_cls` is a string that is not of
            the form :code:`<module>:<class>` with both parts non-empty.
        """
        # Build the entry first, so that an invalid `model_cls` neither logs
        # the "will be overwritten" warning nor replaces an existing entry.
        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2 or not all(split_str):
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model: _BaseRegisteredModel = _LazyRegisteredModel(*split_str)
        else:
            model = _RegisteredModel.from_model_cls(model_cls)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: List[str]):
        """
        Always raise a ``ValueError`` explaining why none of
        ``architectures`` could be used.

        If any of them is registered, the failure happened while
        inspecting/loading it (details were logged). Otherwise none of them
        is supported.
        """
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[Type[nn.Module]]:
        """Load the class for ``model_arch``; ``None`` if unknown/failed."""
        model = self.models.get(model_arch)
        if model is None:
            return None

        # Delegates to the cached module-level function of the same name.
        return _try_load_model_cls(model_arch, model)

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        """Inspect ``model_arch``; ``None`` if unknown or inspection failed."""
        model = self.models.get(model_arch)
        if model is None:
            return None

        # Delegates to the cached module-level function of the same name.
        return _try_inspect_model_cls(model_arch, model)

    def _normalize_archs(
        self,
        architectures: Union[str, List[str]],
    ) -> List[str]:
        """Wrap a single name in a list; warn if no names are given."""
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        return architectures

    def inspect_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> _ModelInfo:
        """
        Return the capability flags of the first inspectable architecture.

        :raises ValueError: if none of ``architectures`` can be inspected.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return model_info

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[Type[nn.Module], str]:
        """
        Return ``(model_class, arch)`` for the first loadable architecture.

        :raises ValueError: if none of ``architectures`` can be loaded.
        """
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        return self.inspect_model_cls(architectures).is_text_generation_model

    def is_embedding_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        return self.inspect_model_cls(architectures).is_embedding_model

    def is_multimodal_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        return self.inspect_model_cls(architectures).supports_multimodal

    def is_pp_supported_model(
        self,
        architectures: Union[str, List[str]],
    ) -> bool:
        return self.inspect_model_cls(architectures).supports_pp

    def model_has_inner_state(self, architectures: Union[str,
                                                         List[str]]) -> bool:
        return self.inspect_model_cls(architectures).has_inner_state

    def is_attention_free_model(self, architectures: Union[str,
                                                           List[str]]) -> bool:
        return self.inspect_model_cls(architectures).is_attention_free

438
439
440
441
442
443
444
445
446
447
448
449
450

# The global registry, pre-populated with every built-in architecture as a
# lazy entry, so no model module is imported until it is actually needed.
ModelRegistry = _ModelRegistry({
    arch: _LazyRegisteredModel(
        module_name="vllm.model_executor.models." + relname,
        class_name=clsname,
    )
    for arch, (relname, clsname) in _VLLM_MODELS.items()
})

# Return type of the callable handed to `_run_in_subprocess`.
_T = TypeVar("_T")


def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """
    Call ``fn`` in a fresh Python interpreter and return its result.

    ``fn`` and the path of an output file are serialized with ``cloudpickle``
    and piped to ``python -m vllm.model_executor.models.registry``. That
    process (``_run`` in this module) calls ``fn`` and pickles the result
    into the file, which is then read back here.

    Raises:
        RuntimeError: if the subprocess exits with a non-zero status; the
            message includes the subprocess's stderr.
    """
    # NOTE: We use a temporary directory instead of a temporary file to avoid
    # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # `cloudpickle` allows pickling lambda functions directly
        input_bytes = cloudpickle.dumps((fn, output_filepath))

        # cannot use `sys.executable __file__` here because the script
        # contains relative imports
        returned = subprocess.run(
            [sys.executable, "-m", "vllm.model_executor.models.registry"],
            input=input_bytes,
            capture_output=True)

        # check if the subprocess is successful
        try:
            returned.check_returncode()
        except subprocess.CalledProcessError as e:
            # Wrap the error to surface the child's stderr. Decode leniently
            # so that non-UTF-8 output cannot mask the original failure.
            stderr_text = returned.stderr.decode(errors="replace")
            raise RuntimeError(f"Error raised in subprocess:\n"
                               f"{stderr_text}") from e

        # The file was written by the child process we just spawned, so
        # unpickling it is not loading untrusted data.
        with open(output_filepath, "rb") as f:
            return pickle.load(f)


def _run() -> None:
    """
    Entry point of the child process spawned by ``_run_in_subprocess``.

    Reads a pickled ``(fn, output_file)`` pair from stdin, calls ``fn()``
    and pickles its return value into ``output_file``.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    # NOTE: stdin is unpickled as-is. This entry point expects to be fed
    # only by `_run_in_subprocess`; never pipe untrusted data into it.
    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)
489
490
491


# Executed as `python -m vllm.model_executor.models.registry` by
# `_run_in_subprocess`.
if __name__ == "__main__":
    _run()