1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Utilities for selecting and loading models."""
4

5
6
7
import inspect
import warnings
from contextlib import contextmanager
8
from dataclasses import dataclass, field
9
10
11

import torch
from torch import nn
12
from typing_extensions import assert_never
13

14
from vllm.attention import Attention
15
from vllm.attention.layer import MLAAttention
16
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
17
from vllm.logger import init_logger
18
from vllm.model_executor.layers.quantization.base_config import (
19
20
21
    QuantizationConfig,
    QuantizeMethodBase,
)
22
from vllm.model_executor.models.adapters import (
23
24
25
26
27
28
    as_embedding_model,
    as_reward_model,
    as_seq_cls_model,
    try_create_mm_pooling_model_cls,
)
from vllm.model_executor.models.interfaces import SupportsQuant, supports_multimodal
29
from vllm.utils.platform_utils import is_pin_memory_available
30

31
32
# Module-level logger for this file, created with vLLM's `init_logger`.
logger = init_logger(__name__)

33

34
35
36
37
def initialize_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    model_class: type[nn.Module] | None = None,
    model_config: ModelConfig | None = None,
) -> nn.Module:
    """Initialize a model with the given configurations.

    Args:
        vllm_config: The vLLM configuration. If it carries a
            ``quant_config``, that config is prepared for ``model_class`` via
            ``configure_quant_config`` before the model is constructed.
        prefix: Module-name prefix forwarded to the model constructor and to
            ``set_current_vllm_config``.
        model_class: The class to instantiate. If ``None``, it is resolved
            from ``model_config`` with ``get_model_architecture``.
        model_config: The model configuration. Defaults to
            ``vllm_config.model_config``.

    Returns:
        The newly constructed model. New-style classes (whose ``__init__``
        accepts both ``vllm_config`` and ``prefix``) receive exactly those two
        arguments. Old-style classes trigger a ``DeprecationWarning`` and
        receive whichever legacy keyword arguments their ``__init__`` names.
    """
    if model_config is None:
        model_config = vllm_config.model_config
    if model_class is None:
        model_class, _ = get_model_architecture(model_config)

    if vllm_config.quant_config is not None:
        configure_quant_config(vllm_config.quant_config, model_class)

    # Constructor parameter names decide which calling convention to use.
    all_params = set(inspect.signature(model_class.__init__).parameters)
    if "vllm_config" in all_params and "prefix" in all_params:
        # new-style model class
        with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
            return model_class(vllm_config=vllm_config, prefix=prefix)

    msg = (
        "vLLM model class should accept `vllm_config` and `prefix` as "
        "input arguments. Possibly you have an old-style model class"
        " registered from out of tree and it is used for new vLLM version. "
        "Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
        "for the design and update the model class accordingly."
    )
    # stacklevel=2 attributes the warning to the caller of this function.
    warnings.warn(msg, DeprecationWarning, stacklevel=2)

    logger.warning(
        "Trying to guess the arguments for old-style model class %s",
        model_class,
    )
    # try to be compatible with old-style model class: only pass the legacy
    # arguments that the constructor actually declares.
    kwargs = {}
    if "prefix" in all_params:
        kwargs["prefix"] = prefix
    if "config" in all_params:
        kwargs["config"] = model_config.hf_config
    if "cache_config" in all_params:
        kwargs["cache_config"] = vllm_config.cache_config
    if "quant_config" in all_params:
        kwargs["quant_config"] = vllm_config.quant_config
    if "lora_config" in all_params:
        kwargs["lora_config"] = vllm_config.lora_config
    if "scheduler_config" in all_params:
        kwargs["scheduler_config"] = vllm_config.scheduler_config
    with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
        return model_class(**kwargs)


88
89
90
def process_weights_after_loading(
    model: nn.Module, model_config: ModelConfig, target_device: torch.device
) -> None:
    """Run post-load weight processing on every submodule of ``model``.

    Steps, in order:

    1. ``maybe_save_metadata_and_attributes_for_weight_reloading`` is called
       on the model.
    2. Every submodule whose ``quant_method`` is a ``QuantizeMethodBase``
       has ``quant_method.process_weights_after_loading(module)`` invoked
       while its CPU parameters are temporarily placed on ``target_device``
       (see ``device_loading_context``).
    3. Every ``Attention`` / ``MLAAttention`` submodule that defines
       ``process_weights_after_loading`` is called with
       ``model_config.dtype``.

    Args:
        model: The fully loaded model.
        model_config: The model configuration; its ``dtype`` is passed to the
            attention post-processing hooks.
        target_device: Device on which quant methods expect parameters.
    """
    # to avoid circular dependency
    from vllm.model_executor.model_loader.online_quantization import (
        maybe_save_metadata_and_attributes_for_weight_reloading,
    )

    maybe_save_metadata_and_attributes_for_weight_reloading(model, model_config)

    for module in model.modules():
        quant_method = getattr(module, "quant_method", None)
        if isinstance(quant_method, QuantizeMethodBase):
            # When quant methods need to process weights after loading
            # (for repacking, quantizing, etc), they expect parameters
            # to be on the global target device. This scope is for the
            # case where cpu offloading is used, where we will move the
            # parameters onto device for processing and back off after.
            with device_loading_context(module, target_device):
                quant_method.process_weights_after_loading(module)

    # Initialize post-load attention weights for both Attention and MLA.
    # NOTE: Happens after other modules so we can easily decompress weights.
    for module in model.modules():
        if isinstance(module, (Attention, MLAAttention)) and hasattr(
            module, "process_weights_after_loading"
        ):
            # TODO(lucas): see if there is a way to unify the signatures
            # of process_weights_after_loading
            module.process_weights_after_loading(model_config.dtype)


@contextmanager
def device_loading_context(module: torch.nn.Module, target_device: torch.device):
    """Temporarily place a module's CPU parameters on ``target_device``.

    Parameters currently on CPU are moved to ``target_device`` for the
    duration of the ``with`` block. On exit, they are copied back into
    freshly allocated CPU tensors, pinned when ``is_pin_memory_available()``
    says so. If ``target_device`` is itself a CPU device, nothing is moved.
    Parameters that were not on CPU on entry are never touched.

    Args:
        module: Module whose parameters are temporarily relocated.
        target_device: Device on which the ``with`` body expects parameters.

    Yields:
        ``module`` itself.
    """
    if target_device.type == "cpu":
        # If target is CPU, no need to move anything
        yield module
        return

    # Names of parameters that were on CPU on entry and were moved.
    moved_param_names: set[str] = set()

    try:
        # Move CPU parameters to the target device. This loop runs inside the
        # ``try`` so that if a transfer fails part-way through (e.g. out of
        # device memory), parameters moved before the failure are still
        # restored by the ``finally`` below.
        for name, p in module.named_parameters():
            if p.device.type == "cpu":
                moved = p.data.to(target_device)
                # Record only after the copy succeeded, so a failed transfer
                # does not schedule a needless CPU -> CPU restore.
                moved_param_names.add(name)
                p.data = moved
            # Parameters already on target device are not touched

        yield module

    finally:
        # Restore moved parameters to CPU, ignoring new parameters. Every
        # recorded parameter started on CPU, so the restore always targets CPU.
        pin_memory = is_pin_memory_available()
        for name, p in module.named_parameters():
            if name not in moved_param_names:
                # New parameters or parameters already on target device are
                # untouched
                continue
            # `torch.empty_like` does not support `pin_memory` argument
            cpu_data = torch.empty_strided(
                size=p.data.size(),
                stride=p.data.stride(),
                dtype=p.data.dtype,
                layout=p.data.layout,
                device="cpu",
                pin_memory=pin_memory,
            )
            cpu_data.copy_(p.data)
            p.data = cpu_data


162
# Process-wide memo used by `get_model_architecture`: maps an int hash of the
# config fields that determine resolution to the (model class, arch name)
# pair returned by `_get_model_architecture`.
_MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
"""Caches the outputs of `_get_model_architecture`."""


166
def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
    """Resolve the model class for ``model_config`` (uncached).

    The names in ``model_config.hf_config.architectures`` are resolved through
    ``model_config.registry.resolve_model_cls``. If
    ``model_config.convert_type`` is not ``"none"``, the resolved class is then
    adapted:

    * for multimodal models, ``try_create_mm_pooling_model_cls`` is tried
      first, and its result is returned if it is not ``None``;
    * otherwise the class is wrapped by the adapter matching the convert type
      (``"embed"`` -> ``as_embedding_model``, ``"classify"`` ->
      ``as_seq_cls_model``, ``"reward"`` -> ``as_reward_model``).

    Returns:
        A ``(model_class, architecture_name)`` tuple.
    """
    architectures = getattr(model_config.hf_config, "architectures", [])

    model_cls, arch = model_config.registry.resolve_model_cls(
        architectures,
        model_config=model_config,
    )

    if arch == model_config._get_transformers_backend_cls():
        # Invariant: the Transformers backend is not resolved when the user
        # explicitly requested the native vLLM implementation.
        assert model_config.model_impl != "vllm"
        if model_config.model_impl == "auto":
            logger.warning_once(
                "%s has no vLLM implementation, falling back to Transformers "
                "implementation. Some features may not be supported and "
                "performance may not be optimal.",
                arch,
            )

    convert_type = model_config.convert_type
    if convert_type != "none" and supports_multimodal(model_cls):
        logger.debug_once("Detected conversion of Multi Modal model.")
        converted = try_create_mm_pooling_model_cls(model_cls)
        if converted is not None:
            logger.debug_once("Creating wrapper class to forward pooler.")
            return converted, arch
        # No multimodal wrapper available: fall through to the generic
        # adapters below.
        logger.debug_once("Attempting direct conversion.")

    if convert_type == "none":
        pass
    elif convert_type == "embed":
        logger.debug_once("Converting to embedding model.")
        model_cls = as_embedding_model(model_cls)
    elif convert_type == "classify":
        logger.debug_once("Converting to sequence classification model.")
        model_cls = as_seq_cls_model(model_cls)
    elif convert_type == "reward":
        logger.debug_once("Converting to reward model.")
        model_cls = as_reward_model(model_cls)
    else:
        assert_never(convert_type)

    return model_cls, arch
209
210


211
212
213
214
215
216
217
218
219
220
221
def get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
    """Return ``(model_class, architecture_name)`` for ``model_config``, memoized.

    Results of ``_get_model_architecture`` are cached in
    ``_MODEL_ARCH_BY_HASH`` under a hash of the config fields that influence
    resolution: model name, convert type, runner type, ``trust_remote_code``,
    model implementation, and the HF ``architectures`` list.
    """
    # ``hf_config.architectures`` may exist but be ``None``, in which case the
    # ``getattr`` default does not apply and ``tuple(None)`` would raise
    # ``TypeError``. Normalize it to an empty sequence for the cache key.
    architectures = getattr(model_config.hf_config, "architectures", None) or ()
    key = hash(
        (
            model_config.model,
            model_config.convert_type,
            model_config.runner_type,
            model_config.trust_remote_code,
            model_config.model_impl,
            tuple(architectures),
        )
    )

    # Cached values are always (class, str) tuples, never None.
    cached = _MODEL_ARCH_BY_HASH.get(key)
    if cached is not None:
        return cached

    model_arch = _get_model_architecture(model_config)
    _MODEL_ARCH_BY_HASH[key] = model_arch
    return model_arch


230
231
232
233
def get_model_cls(model_config: ModelConfig) -> type[nn.Module]:
    """Return just the model class that ``model_config`` resolves to.

    Thin accessor over ``get_model_architecture``; the architecture name it
    also returns is discarded.
    """
    model_cls, _arch_name = get_model_architecture(model_config)
    return model_cls


234
235
def get_architecture_class_name(model_config: ModelConfig) -> str:
    """Return the architecture name that ``model_config`` resolves to.

    Thin accessor over ``get_model_architecture``; the model class it also
    returns is discarded.
    """
    _model_cls, arch_name = get_model_architecture(model_config)
    return arch_name
236
237
238
239
240
241


@dataclass
class ParamMapping:
    """
    A class to handle parameter mapping for model weight loading.
242
    It creates a bidirectional mapping between packed parameters and their
243
244
    constituent parts.
    """
245

246
    packed_mapping: dict[str, list[str]]
247
    inverse_packed_mapping: dict[str, tuple[str, int]] = field(default_factory=dict)
248
249
250
251
252
253
254
255
256
257
258

    def __post_init__(self):
        for packed_name, sub_params in self.packed_mapping.items():
            # Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
            if len(sub_params) == 1 and sub_params[0] == packed_name:
                continue
            for index, param_name in enumerate(sub_params):
                self.inverse_packed_mapping[param_name] = (
                    packed_name,
                    index,
                )
259

260
    def get_sub_modules(self, module_name: str) -> tuple[str, list[str]] | None:
261
262
263
264
        for key, value in self.packed_mapping.items():
            if module_name.endswith(key):
                return key, value
        return None
265
266


267
268
269
def configure_quant_config(
    quant_config: QuantizationConfig, model_class: type[nn.Module]
) -> None:
    """
    Pass packed_modules_mapping by reference to quant_config so that
    quant_config can properly match fused modules.

    Note that model attributes are passed by reference to quant_config,
    enabling them to be updated by model_class.__new__ (ex. chatglm, qwen).

    This only acts on model classes that do not subclass ``SupportsQuant``.
    For those, a class-level ``hf_to_vllm_mapper`` (if present) is applied
    via ``quant_config.apply_vllm_mapper``, and a class-level
    ``packed_modules_mapping`` (if present) is assigned to
    ``quant_config.packed_modules_mapping``.

    Once the `SupportsQuant` mixin has been added to all models, this
    function can be removed.
    """
    if issubclass(model_class, SupportsQuant):
        # Classes using the SupportsQuant mixin are not handled here.
        return

    hf_to_vllm_mapper = getattr(model_class, "hf_to_vllm_mapper", None)
    packed_mapping = getattr(model_class, "packed_modules_mapping", None)

    # pass mappings by reference to quant_config
    if hf_to_vllm_mapper is not None:
        quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
    if packed_mapping is not None:
        quant_config.packed_modules_mapping = packed_mapping