model_loader.py 5.16 KB
Newer Older
1
"""Utilities for selecting and loading models."""
2
import contextlib
3
4
from typing import Type

Woosuk Kwon's avatar
Woosuk Kwon committed
5
import torch
Woosuk Kwon's avatar
Woosuk Kwon committed
6
import torch.nn as nn
7
from transformers import PretrainedConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
8

Woosuk Kwon's avatar
Woosuk Kwon committed
9
from vllm.config import ModelConfig
10
from vllm.model_executor.models import *
11
12
from vllm.model_executor.weight_utils import (get_quant_config,
                                              initialize_dummy_weights)
13
14
15
16
from vllm.utils import is_hip
from vllm.logger import init_logger

logger = init_logger(__name__)


# TODO(woosuk): Lazy-load the model classes.
# Maps an architecture name, as listed in a HuggingFace config's
# ``architectures`` field, to the vLLM model class implementing it.
# Several names may alias the same class. Dict order is preserved and shows
# up in the "Supported architectures" error messages.
_MODEL_REGISTRY = {
    "AquilaModel": AquilaForCausalLM,
    "AquilaForCausalLM": AquilaForCausalLM,  # AquilaChat2
    "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
    "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
    "BloomForCausalLM": BloomForCausalLM,
    "ChatGLMModel": ChatGLMForCausalLM,
    "ChatGLMForConditionalGeneration": ChatGLMForCausalLM,
    "FalconForCausalLM": FalconForCausalLM,
    "GPT2LMHeadModel": GPT2LMHeadModel,
    "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
    "GPTJForCausalLM": GPTJForCausalLM,
    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
    "InternLMForCausalLM": InternLMForCausalLM,
    "LlamaForCausalLM": LlamaForCausalLM,
    "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
    "MistralForCausalLM": MistralForCausalLM,
    "MixtralForCausalLM": MixtralForCausalLM,
    # transformers's mpt class has lower case
    "MptForCausalLM": MPTForCausalLM,
    "MPTForCausalLM": MPTForCausalLM,
    "OPTForCausalLM": OPTForCausalLM,
    "PhiForCausalLM": PhiForCausalLM,
    "QWenLMHeadModel": QWenLMHeadModel,
    "RWForCausalLM": FalconForCausalLM,
    "YiForCausalLM": YiForCausalLM,
}

# Models to be disabled in ROCm.
# On a HIP build these names are removed from the registry, so a lookup falls
# through to the dedicated ROCm error in the architecture resolver. Every name
# listed here must be a registry key, or the ``del`` raises KeyError at import.
_ROCM_UNSUPPORTED_MODELS = []
if is_hip():
    for rocm_model in _ROCM_UNSUPPORTED_MODELS:
        del _MODEL_REGISTRY[rocm_model]

# Models partially supported in ROCm: architecture name -> reason. These stay
# loadable, but a warning with the reason is logged when one is selected on a
# HIP build.
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
    "MistralForCausalLM":
    "Sliding window attention is not supported in ROCm's flash attention",
}


@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily set torch's global default dtype to ``dtype``.

    On exit the previous default dtype is restored. This also happens when the
    body of the ``with`` statement raises.

    Args:
        dtype: The dtype passed to ``torch.set_default_dtype`` for the
            duration of the context.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Without ``finally``, an exception raised inside the ``with`` body
        # (e.g. during model construction or weight loading) would skip the
        # restore and leave the global default dtype changed.
        torch.set_default_dtype(old_dtype)


def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    """Resolve the vLLM model class for a HuggingFace model config.

    The names in ``config.architectures`` are tried in order. The first one
    found in ``_MODEL_REGISTRY`` is returned. On a HIP (ROCm) build, a warning
    is logged if that architecture is only partially supported there.

    Args:
        config: The HuggingFace model config.

    Returns:
        The model class registered for the first matching architecture.

    Raises:
        ValueError: If an architecture listed before any registered one is in
            ``_ROCM_UNSUPPORTED_MODELS``, or if no listed architecture is
            registered (including when none are listed at all).
    """
    # ``architectures`` may be absent or explicitly ``None``. ``None`` is
    # HuggingFace's PretrainedConfig default. Treat both as an empty list so
    # the caller gets the ValueError below rather than a TypeError from
    # iterating ``None``.
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            if is_hip() and arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
                logger.warning(
                    f"{arch} is not fully supported in ROCm. Reason: "
                    f"{_ROCM_PARTIALLY_SUPPORTED_MODELS[arch]}")
            return _MODEL_REGISTRY[arch]
        elif arch in _ROCM_UNSUPPORTED_MODELS:
            # On HIP builds these names were deleted from the registry at
            # import time, so they land here instead of the first branch.
            raise ValueError(
                f"Model architecture {arch} is not supported by ROCm for now. \n"
                f"Supported architectures {list(_MODEL_REGISTRY.keys())}")
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")


def get_model(model_config: ModelConfig) -> nn.Module:
    """Build the model described by ``model_config`` and populate its weights.

    1. The model class is resolved from ``model_config.hf_config``.
    2. When ``model_config.quantization`` is set:
       - a quantization config is created;
       - it is validated against the current GPU's compute capability and
         against ``model_config.dtype``;
       - its linear method is passed to the model constructor.
    3. The model is constructed on the ``"cuda"`` device with
       ``model_config.dtype`` as torch's default dtype.
    4. The weights are then either filled with dummy values
       (``load_format == "dummy"``) or loaded via ``model.load_weights``.

    Returns:
        The model, after ``.eval()``.

    Raises:
        ValueError: If the architecture is not supported, or the requested
            quantization method does not support the current GPU or
            ``model_config.dtype``.
    """
    hf_config = model_config.hf_config
    model_class = _get_model_architecture(hf_config)

    # Resolve the (maybe quantized) linear method; None means unquantized.
    linear_method = None
    quant_method = model_config.quantization
    if quant_method is not None:
        quant_config = get_quant_config(quant_method, model_config.model,
                                        hf_config, model_config.download_dir)

        # Fold the (major, minor) compute capability into one integer,
        # e.g. (8, 0) -> 80, to compare against the method's minimum.
        major, minor = torch.cuda.get_device_capability()
        device_capability = major * 10 + minor
        if device_capability < quant_config.get_min_capability():
            raise ValueError(
                f"The quantization method {quant_method} is not "
                "supported for the current GPU. "
                f"Minimum capability: {quant_config.get_min_capability()}. "
                f"Current capability: {device_capability}.")

        allowed_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in allowed_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {quant_method}. Supported dtypes: "
                f"{allowed_dtypes}")

        linear_method = quant_config.get_linear_method()

    with _set_default_torch_dtype(model_config.dtype):
        # Construct the model directly on the GPU. Its parameter values are
        # placeholders that the branch below overwrites.
        with torch.device("cuda"):
            model = model_class(hf_config, linear_method)

        if model_config.load_format != "dummy":
            # Load the weights from the cached or downloaded files.
            model.load_weights(model_config.model, model_config.download_dir,
                               model_config.load_format,
                               model_config.revision)
        else:
            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)

    return model.eval()