model_loader.py 4.13 KB
Newer Older
1
"""Utilities for selecting and loading models."""
2
import contextlib
3
4
from typing import Type

Woosuk Kwon's avatar
Woosuk Kwon committed
5
import torch
Woosuk Kwon's avatar
Woosuk Kwon committed
6
import torch.nn as nn
7
from transformers import PretrainedConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
8

Woosuk Kwon's avatar
Woosuk Kwon committed
9
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
10
from vllm.model_executor.models import *  # pylint: disable=wildcard-import
11
12
from vllm.model_executor.weight_utils import (get_quant_config,
                                              initialize_dummy_weights)
Woosuk Kwon's avatar
Woosuk Kwon committed
13

Woosuk Kwon's avatar
Woosuk Kwon committed
14
15
# TODO(woosuk): Lazy-load the model classes.
# Maps each architecture name that may appear in a HuggingFace config's
# ``architectures`` list to the model class used to instantiate it.
# Several names are aliases that resolve to the same class.
_MODEL_REGISTRY = {
    "AquilaModel": AquilaForCausalLM,
    "AquilaForCausalLM": AquilaForCausalLM,  # AquilaChat2
    "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
    "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
    "BloomForCausalLM": BloomForCausalLM,
    "ChatGLMModel": ChatGLMForCausalLM,
    "FalconForCausalLM": FalconForCausalLM,
    "GPT2LMHeadModel": GPT2LMHeadModel,
    "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
    "GPTJForCausalLM": GPTJForCausalLM,
    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
    "InternLMForCausalLM": InternLMForCausalLM,
    "LlamaForCausalLM": LlamaForCausalLM,
    "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
    "MistralForCausalLM": MistralForCausalLM,
    # The transformers MPT class name uses lower case ("Mpt"), so both
    # spellings are registered.
    "MptForCausalLM": MPTForCausalLM,
    "MPTForCausalLM": MPTForCausalLM,
    "OPTForCausalLM": OPTForCausalLM,
    "QWenLMHeadModel": QWenLMHeadModel,
    "RWForCausalLM": FalconForCausalLM,
    "YiForCausalLM": YiForCausalLM,
}

40

41
42
43
44
45
46
47
48
49
@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(old_dtype)


50
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    """Return the model class registered for one of ``config``'s architectures.

    The names in ``config.architectures`` are tried in order, and the first
    one present in ``_MODEL_REGISTRY`` wins.

    Args:
        config: The HuggingFace model config to resolve.

    Returns:
        The model class registered under the first matching architecture name.

    Raises:
        ValueError: If none of the listed architectures is registered. This
            includes the case where the config lists no architectures at all.
    """
    architectures = getattr(config, "architectures", [])
    # The attribute can exist but be ``None``. Treat that like an empty list
    # so the caller gets the descriptive ValueError below rather than a
    # TypeError from iterating ``None``.
    if architectures is None:
        architectures = []
    for arch in architectures:
        model_cls = _MODEL_REGISTRY.get(arch)
        if model_cls is not None:
            return model_cls
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
Woosuk Kwon's avatar
Woosuk Kwon committed
58
59


60
61
def _get_linear_method(model_config: ModelConfig):
    """Build the (maybe quantized) linear method for ``model_config``.

    Returns ``None`` when ``model_config.quantization`` is ``None``.
    Otherwise, it loads the quantization config and checks that the current
    CUDA device and ``model_config.dtype`` are supported by that method. It
    then returns the method's linear method.

    Raises:
        ValueError: If the GPU's compute capability is below the method's
            minimum, or if ``model_config.dtype`` is not one of the method's
            supported activation dtypes.
    """
    if model_config.quantization is None:
        return None

    quant_config = get_quant_config(model_config.quantization,
                                    model_config.model,
                                    model_config.download_dir)

    # Fold the (major, minor) device capability into one integer, e.g.
    # (8, 0) -> 80, so it can be compared with the method's minimum.
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor
    min_capability = quant_config.get_min_capability()
    if capability < min_capability:
        raise ValueError(
            f"The quantization method {model_config.quantization} is not "
            "supported for the current GPU. "
            f"Minimum capability: {min_capability}. "
            f"Current capability: {capability}.")

    supported_dtypes = quant_config.get_supported_act_dtypes()
    if model_config.dtype not in supported_dtypes:
        raise ValueError(
            f"{model_config.dtype} is not supported for quantization "
            f"method {model_config.quantization}. Supported dtypes: "
            f"{supported_dtypes}")
    return quant_config.get_linear_method()


def get_model(model_config: ModelConfig) -> nn.Module:
    """Instantiate and load the model described by ``model_config``.

    The model class is resolved from ``model_config.hf_config``. It is
    constructed while ``model_config.dtype`` is torch's default dtype.

    - When ``model_config.load_format == "dummy"``, the model is moved to
      CUDA and its weights are filled by ``initialize_dummy_weights``.
    - Otherwise, weights are loaded with the model's ``load_weights`` from
      ``model_config.model`` / ``model_config.download_dir``, and the model
      is then moved to CUDA.

    Returns:
        The CUDA-resident model switched to eval mode.

    Raises:
        ValueError: If the architecture is not registered, or if the
            requested quantization is incompatible with the GPU or dtype.
    """
    model_class = _get_model_architecture(model_config.hf_config)
    linear_method = _get_linear_method(model_config)

    with _set_default_torch_dtype(model_config.dtype):
        # Create a model instance.
        # The weights will be initialized as empty tensors.
        model = model_class(model_config.hf_config, linear_method)
        if model_config.load_format == "dummy":
            model = model.cuda()
            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)
        else:
            # Load the weights from the cached or downloaded files.
            model.load_weights(model_config.model, model_config.download_dir,
                               model_config.load_format, model_config.revision)
            model = model.cuda()
    return model.eval()