import importlib
from typing import List, Optional, Type

import torch.nn as nn

from vllm.logger import init_logger
7
from vllm.utils import is_hip, is_neuron
8
9
10
11
12

logger = init_logger(__name__)

# Architecture name -> (module, class).
# The module is resolved under vllm.model_executor.models by the loader;
# several architectures alias the same implementation (e.g. Llama variants).
_MODELS = {
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
}

# Models not supported by ROCm.
51
_ROCM_UNSUPPORTED_MODELS = []
52
53
54
55

# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
Junyang Lin's avatar
Junyang Lin committed
56
57
    "Qwen2ForCausalLM":
    "Sliding window attention is not yet supported in ROCm's flash attention",
58
59
    "MistralForCausalLM":
    "Sliding window attention is not yet supported in ROCm's flash attention",
60
61
    "MixtralForCausalLM":
    "Sliding window attention is not yet supported in ROCm's flash attention",
62
63
}

64
65
66
# Models not supported by Neuron.
_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"}

class ModelRegistry:
    """Static registry resolving architecture names to model classes.

    Uses the module-level tables (_MODELS, _ROCM_*, _NEURON_SUPPORTED_MODELS)
    to locate and import the implementing module on demand.
    """

    @staticmethod
    def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
        """Return the model class for ``model_arch``.

        Returns None when the architecture is unknown (or the class is
        missing from its module). Raises ValueError when the architecture
        is known but unsupported on the current platform (ROCm / Neuron).
        """
        if model_arch not in _MODELS:
            return None
        module_name, model_cls_name = _MODELS[model_arch]
        if is_hip():
            if model_arch in _ROCM_UNSUPPORTED_MODELS:
                raise ValueError(
                    f"Model architecture {model_arch} is not supported by "
                    "ROCm for now.")
            if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
                # Partially supported: load anyway, but warn the user.
                # Lazy %-args so formatting is skipped if warnings are off.
                logger.warning(
                    "Model architecture %s is partially supported by ROCm: %s",
                    model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
        elif is_neuron():
            if model_arch not in _NEURON_SUPPORTED_MODELS:
                raise ValueError(
                    f"Model architecture {model_arch} is not supported by "
                    "Neuron for now.")
            # Neuron uses a dedicated module implementation; the original
            # re-checked is_neuron() after the _MODELS lookup — the override
            # belongs in this branch, avoiding the duplicate platform probe.
            module_name = _NEURON_SUPPORTED_MODELS[model_arch]

        module = importlib.import_module(
            f"vllm.model_executor.models.{module_name}")
        # Missing attribute yields None rather than AttributeError.
        return getattr(module, model_cls_name, None)

    @staticmethod
    def get_supported_archs() -> List[str]:
        """Return all registered architecture names (in table order)."""
        return list(_MODELS.keys())

Woosuk Kwon's avatar
Woosuk Kwon committed
100
101

__all__ = [
102
    "ModelRegistry",
Woosuk Kwon's avatar
Woosuk Kwon committed
103
]