"""Registry mapping model architecture names to model classes."""
import importlib
from typing import List, Optional, Type

import torch.nn as nn

from vllm.logger import init_logger
from vllm.utils import is_hip

logger = init_logger(__name__)

# Architecture -> (module, class).
# The module name is relative to ``vllm.model_executor.models`` and is only
# imported when ModelRegistry.load_model_cls is called for that architecture.
# Several architectures share one implementation (for example the Aquila,
# InternLM and Mistral entries all resolve to the llama module).
_MODELS = {
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
}

# Models not supported by ROCm.
# ModelRegistry.load_model_cls raises ValueError for any architecture listed
# here when is_hip() is true. Currently empty.
_ROCM_UNSUPPORTED_MODELS: List[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
# The reason string is appended to the warning that
# ModelRegistry.load_model_cls logs when is_hip() is true.
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
    "Qwen2ForCausalLM":
    "Sliding window attention is not yet supported in ROCm's flash attention",
    "MistralForCausalLM":
    "Sliding window attention is not yet supported in ROCm's flash attention",
    "MixtralForCausalLM":
    "Sliding window attention is not yet supported in ROCm's flash attention",
}


class ModelRegistry:
    """Resolves architecture names registered in ``_MODELS`` to classes."""

    @staticmethod
    def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
        """Import and return the class registered for ``model_arch``.

        Args:
            model_arch: Architecture name to look up in ``_MODELS``.

        Returns:
            The attribute named by the ``_MODELS`` entry, or ``None`` when
            the architecture is not registered or the imported module has
            no attribute of that name.

        Raises:
            ValueError: If ``is_hip()`` is true and the architecture is
                listed in ``_ROCM_UNSUPPORTED_MODELS``.
        """
        entry = _MODELS.get(model_arch)
        if entry is None:
            return None

        if is_hip():
            if model_arch in _ROCM_UNSUPPORTED_MODELS:
                raise ValueError(
                    f"Model architecture {model_arch} is not supported by "
                    "ROCm for now.")
            # Partial support is not fatal: log the reason and carry on.
            rocm_caveat = _ROCM_PARTIALLY_SUPPORTED_MODELS.get(model_arch)
            if rocm_caveat is not None:
                logger.warning(
                    f"Model architecture {model_arch} is partially supported "
                    "by ROCm: " + rocm_caveat)

        module_name, model_cls_name = entry
        module_path = f"vllm.model_executor.models.{module_name}"
        module = importlib.import_module(module_path)
        return getattr(module, model_cls_name, None)

    @staticmethod
    def get_supported_archs() -> List[str]:
        """Return every architecture name registered in ``_MODELS``."""
        return list(_MODELS)

# Public names exported by this module.

__all__ = [
    "ModelRegistry",
]