# utils.py
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Utilities for selecting and loading models."""
4

5
6
7
import inspect
import warnings
from contextlib import contextmanager
8
from dataclasses import dataclass, field
9
10
11

import torch
from torch import nn
12
from typing_extensions import assert_never
13

14
from vllm.attention.layer import Attention, MLAAttention
15
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
16
from vllm.logger import init_logger
17
from vllm.model_executor.layers.quantization.base_config import (
18
19
20
    QuantizationConfig,
    QuantizeMethodBase,
)
21
from vllm.model_executor.models.interfaces import SupportsQuant
22
from vllm.utils.platform_utils import is_pin_memory_available
23

24
25
logger = init_logger(__name__)  # module-level logger shared by the helpers in this file

26

27
28
29
30
def initialize_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    model_class: type[nn.Module] | None = None,
    model_config: ModelConfig | None = None,
) -> nn.Module:
    """Initialize a model with the given configurations.

    Args:
        vllm_config: Top-level engine configuration. It is installed via
            `set_current_vllm_config` while the model constructor runs.
        prefix: Name prefix forwarded to the model constructor.
        model_class: Class to instantiate. When omitted it is resolved from
            `model_config` with `get_model_architecture`.
        model_config: Model configuration; defaults to
            `vllm_config.model_config`.

    Returns:
        The constructed model instance.

    New-style classes (whose `__init__` accepts both `vllm_config` and
    `prefix`) are constructed directly. Any other class triggers a
    `DeprecationWarning`, and its constructor arguments are guessed from
    the parameter names of its `__init__`.
    """
    if model_config is None:
        model_config = vllm_config.model_config
    if model_class is None:
        model_class, _ = get_model_architecture(model_config)

    if vllm_config.quant_config is not None:
        configure_quant_config(vllm_config.quant_config, model_class)

    signatures = inspect.signature(model_class.__init__)
    # Only membership tests are performed below, so a set is enough.
    all_params = {param.name for param in signatures.parameters.values()}
    if "vllm_config" in all_params and "prefix" in all_params:
        # new-style model class
        with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
            return model_class(vllm_config=vllm_config, prefix=prefix)

    msg = (
        "vLLM model class should accept `vllm_config` and `prefix` as "
        "input arguments. Possibly you have an old-style model class"
        " registered from out of tree and it is used for new vLLM version. "
        "Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
        "for the design and update the model class accordingly."
    )
    warnings.warn(msg, DeprecationWarning, stacklevel=2)

    logger.warning(
        "Trying to guess the arguments for old-style model class %s",
        model_class,
    )
    # try to be compatible with old-style model class
    kwargs = {}
    # A class may accept `vllm_config` without `prefix`. Such a class still
    # needs the config; without it, construction fails with a
    # missing-argument TypeError.
    if "vllm_config" in all_params:
        kwargs["vllm_config"] = vllm_config
    if "prefix" in all_params:
        kwargs["prefix"] = prefix
    if "config" in all_params:
        kwargs["config"] = model_config.hf_config
    if "cache_config" in all_params:
        kwargs["cache_config"] = vllm_config.cache_config
    if "quant_config" in all_params:
        kwargs["quant_config"] = vllm_config.quant_config
    if "lora_config" in all_params:
        kwargs["lora_config"] = vllm_config.lora_config
    if "scheduler_config" in all_params:
        kwargs["scheduler_config"] = vllm_config.scheduler_config
    with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
        return model_class(**kwargs)


81
82
83
def process_weights_after_loading(
    model: nn.Module, model_config: ModelConfig, target_device: torch.device
) -> None:
    """Run the post-load weight hooks of ``model`` and its submodules.

    Nothing happens if ``model`` carries a truthy
    ``process_weights_after_loading_already_called`` attribute. Otherwise
    the steps are:

    1. ``maybe_save_metadata_and_attributes_for_weight_reloading`` is given
       the model and its config (online-quantization support).
    2. Every submodule whose ``quant_method`` is a ``QuantizeMethodBase``
       has ``quant_method.process_weights_after_loading(submodule)`` called
       inside ``device_loading_context(submodule, target_device)``.
    3. Every ``Attention`` / ``MLAAttention`` submodule that defines
       ``process_weights_after_loading`` has it called with
       ``model_config.dtype``.
    """
    if getattr(model, "process_weights_after_loading_already_called", False):
        # Later invocations on the same model are skipped.
        logger.debug_once(
            "process_weights_after_loading already called for model %s", model
        )
        return

    # Imported here rather than at module level to avoid a circular import.
    from vllm.model_executor.model_loader.online_quantization import (
        maybe_save_metadata_and_attributes_for_weight_reloading,
    )

    maybe_save_metadata_and_attributes_for_weight_reloading(model, model_config)

    # Pass 1 - quantization methods. They expect their parameters on the
    # target device (for repacking, quantizing, ...). When CPU offloading is
    # in use, the context manager moves CPU parameters there for the call and
    # back off afterwards.
    for submodule in model.modules():
        method = getattr(submodule, "quant_method", None)
        if not isinstance(method, QuantizeMethodBase):
            continue
        with device_loading_context(submodule, target_device):
            method.process_weights_after_loading(submodule)

    # Pass 2 - attention layers (regular and MLA). Runs after pass 1 so that
    # weights handled by other modules (e.g. decompression) are ready.
    attention_types = (Attention, MLAAttention)
    for submodule in model.modules():
        if not isinstance(submodule, attention_types):
            continue
        if not hasattr(submodule, "process_weights_after_loading"):
            continue
        # TODO(lucas): see if there is a way to unify the signatures
        # of process_weights_after_loading
        submodule.process_weights_after_loading(model_config.dtype)


@contextmanager
def device_loading_context(module: torch.nn.Module, target_device: torch.device):
    """Temporarily move the CPU-resident parameters of ``module`` to ``target_device``.

    Yields ``module`` itself. If ``target_device`` is a CPU device, nothing
    is moved. Otherwise every parameter currently on CPU is moved to
    ``target_device`` before yielding.

    On exit (including on exception), the same parameters, looked up again
    by name, are copied back into freshly allocated CPU storage. That storage
    is pinned when ``is_pin_memory_available()`` returns true. Parameters
    that were not on CPU, or that did not exist on entry, are left where
    they are.
    """
    if target_device.type == "cpu":
        # Already on the target: nothing to move.
        yield module
        return

    # Names of the parameters moved off the CPU; only these are restored.
    moved_from_cpu: set[str] = set()
    for param_name, param in module.named_parameters():
        if param.device.type != "cpu":
            continue
        moved_from_cpu.add(param_name)
        param.data = param.data.to(target_device)

    try:
        yield module
    finally:
        use_pinned = is_pin_memory_available()
        for param_name, param in module.named_parameters():
            if param_name not in moved_from_cpu:
                continue
            # `torch.empty_like` does not take `pin_memory`, so allocate the
            # host buffer with `torch.empty_strided`, which also keeps the
            # original size/stride/dtype/layout.
            host_copy = torch.empty_strided(
                size=param.data.size(),
                stride=param.data.stride(),
                dtype=param.data.dtype,
                layout=param.data.layout,
                device="cpu",
                pin_memory=use_pinned,
            )
            host_copy.copy_(param.data)
            param.data = host_copy


163
_MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
164
165
166
"""Caches the outputs of `_get_model_architecture`."""


167
def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
168
    from vllm.model_executor.models.adapters import as_embedding_model, as_seq_cls_model
Cyrus Leung's avatar
Cyrus Leung committed
169

170
    architectures = getattr(model_config.hf_config, "architectures", [])
171

172
173
174
175
176
177
    model_cls, arch = model_config.registry.resolve_model_cls(
        architectures,
        model_config=model_config,
    )

    if arch == model_config._get_transformers_backend_cls():
178
179
        assert model_config.model_impl != "vllm"
        if model_config.model_impl == "auto":
180
181
182
            logger.warning_once(
                "%s has no vLLM implementation, falling back to Transformers "
                "implementation. Some features may not be supported and "
183
184
185
                "performance may not be optimal.",
                arch,
            )
186
187
188
189
190
191

    convert_type = model_config.convert_type
    if convert_type == "none":
        pass
    elif convert_type == "embed":
        logger.debug_once("Converting to embedding model.")
192
        model_cls = as_embedding_model(model_cls)
193
194
    elif convert_type == "classify":
        logger.debug_once("Converting to sequence classification model.")
195
        model_cls = as_seq_cls_model(model_cls)
196
197
    else:
        assert_never(convert_type)
198
199

    return model_cls, arch
200
201


202
203
204
205
206
207
208
209
210
211
212
def get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
    """Return ``(model_class, architecture_name)`` for ``model_config``, memoized.

    Results of ``_get_model_architecture`` are cached in
    ``_MODEL_ARCH_BY_HASH``. The key is the hash of the model name,
    convert/runner types, ``trust_remote_code``, ``model_impl`` and the HF
    config's architecture list.

    NOTE(review): entries are keyed by ``hash(...)`` rather than by the tuple
    itself, so two distinct configs whose tuples hash-collide would share
    an entry.
    """
    # The HF config may define `architectures` but set it to None (the
    # transformers `PretrainedConfig` default). `tuple(None)` would raise
    # TypeError, so a missing/None/empty value is normalized to `()`.
    # Non-empty lists produce the same key as before.
    architectures = getattr(model_config.hf_config, "architectures", None) or ()
    key = hash(
        (
            model_config.model,
            model_config.convert_type,
            model_config.runner_type,
            model_config.trust_remote_code,
            model_config.model_impl,
            tuple(architectures),
        )
    )

    cached = _MODEL_ARCH_BY_HASH.get(key)
    if cached is not None:
        return cached

    model_arch = _get_model_architecture(model_config)
    _MODEL_ARCH_BY_HASH[key] = model_arch
    return model_arch


221
222
223
224
def get_model_cls(model_config: ModelConfig) -> type[nn.Module]:
    """Return just the model class resolved by `get_model_architecture`."""
    model_cls, _arch_name = get_model_architecture(model_config)
    return model_cls


225
226
def get_architecture_class_name(model_config: ModelConfig) -> str:
    """Return just the architecture name resolved by `get_model_architecture`."""
    _model_cls, arch_name = get_model_architecture(model_config)
    return arch_name
227
228
229
230
231
232


@dataclass
class ParamMapping:
    """
    A class to handle parameter mapping for model weight loading.
233
    It creates a bidirectional mapping between packed parameters and their
234
235
    constituent parts.
    """
236

237
    packed_mapping: dict[str, list[str]]
238
    inverse_packed_mapping: dict[str, tuple[str, int]] = field(default_factory=dict)
239
240
241
242
243
244
245
246
247
248
249

    def __post_init__(self):
        for packed_name, sub_params in self.packed_mapping.items():
            # Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
            if len(sub_params) == 1 and sub_params[0] == packed_name:
                continue
            for index, param_name in enumerate(sub_params):
                self.inverse_packed_mapping[param_name] = (
                    packed_name,
                    index,
                )
250

251
    def get_sub_modules(self, module_name: str) -> tuple[str, list[str]] | None:
252
253
254
255
        for key, value in self.packed_mapping.items():
            if module_name.endswith(key):
                return key, value
        return None
256
257


258
259
260
def configure_quant_config(
    quant_config: QuantizationConfig, model_class: type[nn.Module]
):
    """
    Share the weight-name mappings of ``model_class`` with ``quant_config``.

    Classes implementing the ``SupportsQuant`` mixin are left alone. For
    any other class:

    - a non-None ``hf_to_vllm_mapper`` class attribute is passed to
      ``quant_config.apply_vllm_mapper``;
    - a non-None ``packed_modules_mapping`` class attribute is assigned to
      ``quant_config.packed_modules_mapping``.

    The mapping is handed over by reference, not copied, so that
    ``quant_config`` can match fused modules. Later in-place updates made by
    ``model_class.__new__`` (ex. chatglm, qwen) therefore stay visible to
    it.

    Once the `SupportsQuant` mixin has been added to all models, this
    function can be removed
    """
    if issubclass(model_class, SupportsQuant):
        return

    # Read both attributes up front, before either is applied.
    mapper = getattr(model_class, "hf_to_vllm_mapper", None)
    packed = getattr(model_class, "packed_modules_mapping", None)

    if mapper is not None:
        quant_config.apply_vllm_mapper(mapper)
    if packed is not None:
        # Same object as the class attribute (by reference).
        quant_config.packed_modules_mapping = packed