# __init__.py
import importlib
from typing import List, Optional, Type

import torch.nn as nn

from vllm.logger import init_logger
from vllm.utils import is_hip

# Module-level logger, created through vLLM's init_logger and named after
# this module.
logger = init_logger(__name__)

# Architecture -> (module, class).
_MODELS = {
    "AquilaModel": ("aquila", "AquilaForCausalLM"),
    "AquilaForCausalLM": ("aquila", "AquilaForCausalLM"),  # AquilaChat2
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
20
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "InternLMForCausalLM": ("internlm", "InternLMForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
36
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
37
38
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
Hyunsung Lee's avatar
Hyunsung Lee committed
39
40
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "YiForCausalLM": ("yi", "YiForCausalLM")
41
42
43
}

# Models not supported by ROCm.
# Architecture names listed here are rejected with a ValueError when the
# registry is queried on a HIP build. Currently empty.
_ROCM_UNSUPPORTED_MODELS: List[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
# The reason string is appended to the warning logged when one of these
# architectures is loaded on a HIP build.
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
    "MistralForCausalLM":
    "Sliding window attention is not yet supported in ROCm's flash attention",
    "MixtralForCausalLM":
    "Sliding window attention is not yet supported in ROCm's flash attention",
}


class ModelRegistry:
    """Static lookup helpers over the module-level ``_MODELS`` table."""

    @staticmethod
    def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
        """Import and return the class registered for ``model_arch``.

        Args:
            model_arch: Architecture name, used as a key into ``_MODELS``.

        Returns:
            The attribute named by the registry entry on the imported
            module, or ``None`` when ``model_arch`` is not a key of
            ``_MODELS`` or the module has no attribute with that name.

        Raises:
            ValueError: If ``is_hip()`` is true and ``model_arch`` is listed
                in ``_ROCM_UNSUPPORTED_MODELS``.
        """
        try:
            entry = _MODELS[model_arch]
        except KeyError:
            # Unknown architecture: report it as "no class" rather than fail.
            return None

        if is_hip():
            if model_arch in _ROCM_UNSUPPORTED_MODELS:
                raise ValueError(
                    f"Model architecture {model_arch} is not supported by "
                    "ROCm for now.")
            if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
                # Loading still proceeds; the caller is only warned.
                reason = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
                message = (
                    f"Model architecture {model_arch} is partially supported "
                    "by ROCm: " + reason)
                logger.warning(message)

        module_name, model_cls_name = entry
        # The implementing module is imported only now, on demand.
        qualified_name = f"vllm.model_executor.models.{module_name}"
        model_module = importlib.import_module(qualified_name)
        return getattr(model_module, model_cls_name, None)

    @staticmethod
    def get_supported_archs() -> List[str]:
        """Return the architecture names registered in ``_MODELS``."""
        return [arch for arch in _MODELS]

# Public names exported by a star-import of this module.
__all__ = [
    "ModelRegistry",
]