# Model registry: maps model architecture names to vLLM model implementations.
import importlib
from typing import List, Optional, Type

import torch.nn as nn

from vllm.logger import init_logger
from vllm.utils import is_hip

# Module logger; ModelRegistry.load_model_cls uses it to warn when a model
# architecture is only partially supported on ROCm.
logger = init_logger(__name__)

# Architecture -> (module, class).
# ``module`` is a submodule of ``vllm.model_executor.models`` and ``class`` is
# the attribute looked up in it. ModelRegistry.load_model_cls imports the
# module only when that architecture is requested. Several architecture names
# may alias the same implementation.
_MODELS = {
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers names its MPT class in mixed case ("Mpt"), so register
    # both spellings.
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
}

# Models not supported by ROCm. ModelRegistry.load_model_cls raises
# ValueError for these when running on ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []

# Shared reason text for the architectures below. The message names sliding
# window attention as the missing ROCm feature.
_ROCM_SWA_REASON = (
    "Sliding window attention is not yet supported in ROCm's flash attention")

# Models partially supported by ROCm. ModelRegistry.load_model_cls logs a
# warning containing the reason for these when running on ROCm.
# Architecture -> Reason.
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
    "Qwen2ForCausalLM": _ROCM_SWA_REASON,
    "MistralForCausalLM": _ROCM_SWA_REASON,
    "MixtralForCausalLM": _ROCM_SWA_REASON,
}


class ModelRegistry:
    """Resolves model architecture names to vLLM model classes.

    Lookups go through the module-level ``_MODELS`` table. The implementing
    submodule is imported only when a class is actually requested.
    """

    @staticmethod
    def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
        """Return the model class registered for ``model_arch``.

        Args:
            model_arch: Architecture name, e.g. ``"LlamaForCausalLM"``.

        Returns:
            The class found in its ``vllm.model_executor.models`` submodule.
            Returns ``None`` if the architecture is not registered, or if the
            submodule has no attribute with the registered class name.

        Raises:
            ValueError: When running on ROCm (``is_hip()``) and the
                architecture is listed as unsupported there.
        """
        registered = _MODELS.get(model_arch)
        if registered is None:
            return None

        if is_hip():
            # Hard-fail on architectures ROCm cannot run; merely warn on
            # those it supports only partially.
            if model_arch in _ROCM_UNSUPPORTED_MODELS:
                raise ValueError(
                    f"Model architecture {model_arch} is not supported by "
                    "ROCm for now.")
            partial_reason = _ROCM_PARTIALLY_SUPPORTED_MODELS.get(model_arch)
            if partial_reason is not None:
                logger.warning(
                    f"Model architecture {model_arch} is partially supported "
                    "by ROCm: " + partial_reason)

        module_name, class_name = registered
        # Import only the requested model's module, on demand.
        model_module = importlib.import_module(
            f"vllm.model_executor.models.{module_name}")
        return getattr(model_module, class_name, None)

    @staticmethod
    def get_supported_archs() -> List[str]:
        """Return every registered architecture name, in registration order."""
        return list(_MODELS)


# Public API of this package: only the registry is exported; the lookup
# tables above stay private.
__all__ = [
    "ModelRegistry",
]