utils.py 4.44 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Utils for model executor."""
4

5
import copy
6
from typing import Any
7
8
9

import torch

10
from vllm.utils.torch_utils import is_torch_equal_or_newer
11

12

13
14
def set_weight_attrs(
    weight: torch.Tensor,
15
    weight_attrs: dict[str, Any] | None,
16
17
18
19
20
21
22
23
24
25
26
27
28
):
    """Set attributes on a weight tensor.

    This method is used to set attributes on a weight tensor. This method
    will not overwrite existing attributes.

    Args:
        weight: The weight tensor.
        weight_attrs: A dictionary of attributes to set on the weight tensor.
    """
    if weight_attrs is None:
        return
    for key, value in weight_attrs.items():
29
        assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}"
30
31
32
33
34
35
36
37
38
39

        # NOTE(woosuk): During weight loading, we often do something like:
        # narrowed_tensor = param.data.narrow(0, offset, len)
        # narrowed_tensor.copy_(real_weight)
        # expecting narrowed_tensor and param.data to share the same storage.
        # However, on TPUs, narrowed_tensor will lazily propagate to the base
        # tensor, which is param.data, leading to the redundant memory usage.
        # This sometimes causes OOM errors during model loading. To avoid this,
        # we sync the param tensor after its weight loader is called.
        # TODO(woosuk): Remove this hack once we have a better solution.
40
        from vllm.platforms import current_platform
41

42
        if current_platform.use_sync_weight_loader() and key == "weight_loader":
43
            value = current_platform.make_synced_weight_loader(value)
44
        setattr(weight, key, value)
45
46


47
48
49
def replace_parameter(
    layer: torch.nn.Module, param_name: str, new_data: torch.Tensor | None
):
50
51
52
53
54
55
56
57
58
    """
    Replace a parameter of a layer while maintaining the ability to reload the weight.
    Called within implementations of the `process_weights_after_loading` method.

    This function should not be called on weights which are tied/shared

    Args:
        layer: Layer containing parameter to replace
        param_name: Name of parameter to replace
59
        new_data: New data of the new parameter, or None to set the parameter to None
60
61
    """
    # should not be used on a tied/shared param
62
63
64
65
66
67

    # If new_data is None, set the parameter to None
    if new_data is None:
        setattr(layer, param_name, None)
        return

68
69
70
71
72
73
74
75
76
77
78
79
    if isinstance(new_data, torch.nn.Parameter):
        new_data = new_data.data
    new_param = torch.nn.Parameter(new_data, requires_grad=False)

    old_param: torch.nn.Parameter | None = getattr(layer, param_name, None)
    if old_param is not None and hasattr(old_param, "weight_loader"):
        weight_loader = old_param.weight_loader
        set_weight_attrs(new_param, {"weight_loader": weight_loader})

    setattr(layer, param_name, new_param)


80
def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
    """Collect the `packed_modules_mapping` of a model.

    If `model` itself defines a non-empty `packed_modules_mapping`, a deep
    copy of it is returned and the children are not inspected. Otherwise the
    mappings of the model's direct children are merged (deep-copied) into one
    dictionary.

    Args:
        model: The module to inspect.

    Returns:
        A new dictionary; mutating it does not affect the mappings stored on
        the model or its children.

    Raises:
        ValueError: If a child maps a key that an earlier child already
            mapped to a different value.
    """
    parent_map = getattr(model, "packed_modules_mapping", None)
    parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}

    # don't infer mapping if the model has defined it explicitly.
    if parent_map:
        return parent_map

    # We only check main components instead of whole model submodules
    for child in model.children():
        child_map = getattr(child, "packed_modules_mapping", None)
        child_map = copy.deepcopy(child_map) if child_map is not None else {}

        if any((k in parent_map and parent_map[k] != v) for k, v in child_map.items()):
            raise ValueError(
                f"Can't update {type(model).__name__}'s packed_modules_mapping "
                f"safely because of conflicts from {type(child).__name__}."
            )
        parent_map.update(child_map)
    return parent_map


def get_moe_expert_mapping(
    model: torch.nn.Module,
) -> list[tuple[str, str, int, str]]:
    """Return the expert mapping exposed by a model or one of its children.

    If `model` has a truthy `get_expert_mapping` attribute, it is called and
    its result is returned. Otherwise the direct children are scanned in
    order and the first one that has a `get_expert_mapping` attribute (not
    `None`) is used.

    Args:
        model: The module to inspect.

    Returns:
        The result of the first `get_expert_mapping()` found, or an empty
        list when neither the model nor any direct child provides one.
    """
    if parent_map := getattr(model, "get_expert_mapping", None):
        return parent_map()

    # We only check main components instead of whole model submodules
    for child in model.children():
        child_map = getattr(child, "get_expert_mapping", None)
        if child_map is not None:
            return child_map()
    return []
115
116
117
118
119
120
121


def maybe_disable_graph_partition(current_backend: str) -> dict[str, bool]:
    """Return an option override that switches graph partition off, if needed.

    Args:
        current_backend: Name of the backend in use.

    Returns:
        `{"graph_partition": False}` when the backend is `"inductor"` and
        `is_torch_equal_or_newer("2.9.0.dev")` holds; an empty dict otherwise.
    """
    # The backend comparison comes first so the torch version is only queried
    # for the inductor backend.
    needs_override = current_backend == "inductor" and is_torch_equal_or_newer(
        "2.9.0.dev"
    )
    return {"graph_partition": False} if needs_override else {}