model_loader.py 5.44 KB
Newer Older
1
"""Utilities for selecting and loading models."""
2
import contextlib
yhu422's avatar
yhu422 committed
3
from typing import Tuple, Type
4

Woosuk Kwon's avatar
Woosuk Kwon committed
5
import torch
6
from torch import nn
Woosuk Kwon's avatar
Woosuk Kwon committed
7

8
from vllm.config import DeviceConfig, ModelConfig
9
from vllm.model_executor.models import ModelRegistry
10
from vllm.model_executor.models.llava import LlavaForConditionalGeneration
11
12
13
from vllm.model_executor.tensorizer_loader import (
    ParameterizedLoadFormat, is_vllm_serialized_tensorizer,
    load_with_tensorizer)
14
15
from vllm.model_executor.weight_utils import (get_quant_config,
                                              initialize_dummy_weights)
16

17
18
19
20
# Model classes that ``get_model`` may construct with an additional
# ``vision_language_config`` keyword argument.
_VISION_MODEL_CLASSES = [LlavaForConditionalGeneration]


@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily set torch's default dtype.

    The previous default is restored when the ``with`` block exits, including
    when it exits because of an exception.

    Args:
        dtype: The dtype to install as the default while the context is active.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore even on error: the default dtype is process-wide state, and
        # a failed model construction must not leak the temporary value.
        torch.set_default_dtype(old_dtype)


def _get_model_architecture(
        model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
    """Resolve the model class to instantiate for ``model_config``.

    The architecture names listed on ``model_config.hf_config`` are tried in
    order, and the first one for which ``ModelRegistry.load_model_cls``
    returns a class is used.

    Args:
        model_config: The model configuration whose ``hf_config`` lists the
            candidate architectures.

    Returns:
        A ``(model_class, architecture_name)`` tuple.

    Raises:
        ValueError: If none of the listed architectures is supported, including
            when the config lists no architectures at all.
    """
    # ``architectures`` may be absent from the HF config or explicitly None;
    # normalize both to an empty list so we reach the informative ValueError
    # below instead of a TypeError on ``in`` / iteration.
    architectures = getattr(model_config.hf_config, "architectures",
                            None) or []
    # Special handling for quantized Mixtral.
    # FIXME(woosuk): This is a temporary hack.
    if (model_config.quantization is not None
            and "MixtralForCausalLM" in architectures):
        architectures = ["QuantMixtralForCausalLM"]

    for arch in architectures:
        model_cls = ModelRegistry.load_model_cls(arch)
        if model_cls is not None:
            return (model_cls, arch)
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {ModelRegistry.get_supported_archs()}")


def get_architecture_class_name(model_config: ModelConfig) -> str:
    """Return the architecture name that ``model_config`` resolves to.

    This is the name half of the ``(model_class, name)`` pair produced by
    ``_get_model_architecture``; any ``ValueError`` it raises propagates.
    """
    _model_cls, arch_name = _get_model_architecture(model_config)
    return arch_name


def _get_linear_method(model_config: ModelConfig):
    """Build the (maybe quantized) linear method for ``model_config``.

    Returns ``None`` when ``model_config.quantization`` is None. Otherwise the
    quantization config is loaded and validated against the current GPU's
    compute capability and ``model_config.dtype`` before the linear method
    is created.

    Raises:
        ValueError: If the GPU capability is below the quantization method's
            minimum, or ``model_config.dtype`` is not in the method's
            ``get_supported_act_dtypes()``.
    """
    if model_config.quantization is None:
        return None

    quant_config = get_quant_config(model_config)

    # Capability is encoded as major * 10 + minor (e.g. (8, 0) -> 80) for the
    # comparison against ``get_min_capability()``.
    # NOTE(review): this queries the current CUDA device regardless of
    # ``device_config``; confirm quantization is only reached on CUDA.
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor
    min_capability = quant_config.get_min_capability()
    if capability < min_capability:
        raise ValueError(
            f"The quantization method {model_config.quantization} is not "
            "supported for the current GPU. "
            f"Minimum capability: {min_capability}. "
            f"Current capability: {capability}.")

    supported_dtypes = quant_config.get_supported_act_dtypes()
    if model_config.dtype not in supported_dtypes:
        raise ValueError(
            f"{model_config.dtype} is not supported for quantization "
            f"method {model_config.quantization}. Supported dtypes: "
            f"{supported_dtypes}")

    return quant_config.get_linear_method()


def _get_model_initialization_kwargs(model_class: Type[nn.Module],
                                     lora_config,
                                     vision_language_config) -> dict:
    """Collect the optional constructor kwargs ``model_class`` accepts.

    Raises:
        ValueError: If ``lora_config`` is set but ``model_class`` does not
            declare ``supported_lora_modules``.
    """
    extra_kwargs = {}
    if hasattr(model_class, "supported_lora_modules"):
        extra_kwargs["lora_config"] = lora_config
    elif lora_config:
        raise ValueError(
            f"Model {model_class.__name__} does not support LoRA, "
            "but LoRA is enabled. Support for this model may "
            "be added in the future. If this is important to you, "
            "please open an issue on github.")
    # Checked independently of LoRA support: a vision-language model that
    # also supports LoRA still needs its vision config.
    if model_class in _VISION_MODEL_CLASSES:
        extra_kwargs["vision_language_config"] = vision_language_config
    return extra_kwargs


def get_model(model_config: ModelConfig, device_config: DeviceConfig,
              **kwargs) -> nn.Module:
    """Instantiate the model described by ``model_config`` and load its weights.

    Args:
        model_config: Model configuration (architecture, dtype, quantization,
            load format, weight location).
        device_config: ``device_config.device`` is the torch device the model
            is constructed on.
        **kwargs: Optional ``lora_config``, ``vision_language_config`` and
            ``tensorizer_config``; other keys are ignored.

    Returns:
        The model, switched to eval mode.

    Raises:
        ValueError: If the architecture is unsupported, the quantization
            method is not supported on this GPU/dtype, LoRA is requested for a
            model that does not support it, or ``load_format`` is
            ``"tensorizer"`` without a ``tensorizer_config``.
    """
    lora_config = kwargs.get("lora_config", None)
    vision_language_config = kwargs.get("vision_language_config", None)
    tensorizer_config = kwargs.get("tensorizer_config", None)

    is_tensorizer = model_config.load_format == "tensorizer"
    if is_tensorizer and tensorizer_config is None:
        # Every tensorizer path below dereferences ``tensorizer_config``;
        # fail fast with a clear message instead of an AttributeError.
        raise ValueError(
            "load_format 'tensorizer' requires a tensorizer_config.")

    model_class = _get_model_architecture(model_config)[0]
    # Get the (maybe quantized) linear method.
    linear_method = _get_linear_method(model_config)
    # Validate before changing the default dtype so a LoRA error cannot leave
    # the process-wide default dtype modified.
    extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
                                                    vision_language_config)

    with _set_default_torch_dtype(model_config.dtype):
        with torch.device(device_config.device):
            if is_tensorizer and is_vllm_serialized_tensorizer(
                    tensorizer_config):
                # The whole model comes back from ``load_with_tensorizer``, so
                # the regular constructor / ``load_weights`` path is skipped.
                extra_kwargs["linear_method"] = linear_method
                tensorizer_config.model_class = model_class
                tensorizer_config.hf_config = model_config.hf_config
                tensorizer_config.dtype = model_config.dtype
                model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
                return model.eval()
            # Create a model instance.
            # The weights will be initialized as empty tensors.
            model = model_class(config=model_config.hf_config,
                                linear_method=linear_method,
                                **extra_kwargs)
        if model_config.load_format == "dummy":
            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)
        else:
            # Load the weights from the cached or downloaded files.
            if is_tensorizer:
                # Provide a dynamic load format for `model.load_weights`
                # to retain tensorizer args from CLI.
                # NOTE(review): this replaces ``load_format`` on the caller's
                # ``model_config`` object in place.
                model_config.load_format = ParameterizedLoadFormat(
                    model_config.load_format)
                model_config.load_format.params = (
                    tensorizer_config._construct_tensorizer_args())

            model.load_weights(
                model_config.model,
                model_config.download_dir,
                model_config.load_format,
                model_config.revision,
            )
    return model.eval()