# __init__.py
import importlib
from typing import List, Optional, Type

import torch.nn as nn

from vllm.logger import init_logger
from vllm.utils import is_hip

logger = init_logger(__name__)

# Architecture -> (module, class).
# The key is the architecture name given to ModelRegistry.load_model_cls.
# The value is (submodule of vllm.model_executor.models to import, name of
# the class attribute to fetch from that submodule). Several architectures
# intentionally share one implementation (e.g. the Llama-compatible ones).
_MODELS = {
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
}

# Models not supported by ROCm.
# Architecture names listed here make ModelRegistry.load_model_cls raise
# ValueError when is_hip() is true. Currently empty.
_ROCM_UNSUPPORTED_MODELS: List[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
# The reason string is appended to the warning that
# ModelRegistry.load_model_cls logs when is_hip() is true.
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
    "Qwen2ForCausalLM":
    "Sliding window attention is not yet supported in ROCm's flash attention",
    "MistralForCausalLM":
    "Sliding window attention is not yet supported in ROCm's flash attention",
    "MixtralForCausalLM":
    "Sliding window attention is not yet supported in ROCm's flash attention",
}


class ModelRegistry:
    """Resolves model architecture names to their implementation classes."""

    @staticmethod
    def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
        """Import and return the class registered for ``model_arch``.

        Args:
            model_arch: Architecture name, looked up as a key of ``_MODELS``.

        Returns:
            The class named by the ``_MODELS`` entry, or ``None`` when the
            architecture is not registered or the imported module has no
            attribute with the registered class name.

        Raises:
            ValueError: If ``is_hip()`` is true and the architecture is
                listed in ``_ROCM_UNSUPPORTED_MODELS``.
        """
        location = _MODELS.get(model_arch)
        if location is None:
            # Unknown architecture: nothing to import.
            return None

        on_rocm = is_hip()
        if on_rocm and model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(
                f"Model architecture {model_arch} is not supported by "
                "ROCm for now.")
        if on_rocm and model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            # Partial support is only reported; loading still proceeds.
            caveat = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                f"Model architecture {model_arch} is partially supported "
                "by ROCm: " + caveat)

        # The import is deferred to this point so that only the requested
        # model's module is imported.
        submodule, cls_name = location
        qualified_name = f"vllm.model_executor.models.{submodule}"
        return getattr(importlib.import_module(qualified_name), cls_name, None)

    @staticmethod
    def get_supported_archs() -> List[str]:
        """Return the registered architecture names, in registration order."""
        return list(_MODELS)

# Public names exported by this module.
__all__ = [
    "ModelRegistry",
]