utils.py 5.38 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Utils for model executor."""
4

5
import copy
6
from typing import Any
7
8
9

import torch

10
from vllm.utils.torch_utils import is_torch_equal_or_newer
11

12

13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def pad_weight(weight: torch.Tensor, num_pad: int, pad_dim: int = 0):
    """Append ``num_pad`` zero entries to ``weight`` along one dimension.

    For a 1-D tensor the zeros are always appended along dim 0 and
    ``pad_dim`` is not consulted. For a 2-D tensor the zeros are appended
    along ``pad_dim``, which must be 0 or 1. The padding uses ``weight``'s
    dtype and device, and the returned tensor is contiguous.

    Args:
        weight: 1-D or 2-D tensor to pad.
        num_pad: Number of zero entries to append along the padded dimension.
        pad_dim: Dimension to pad for 2-D tensors (0 = rows, 1 = columns).

    Returns:
        A new contiguous tensor containing ``weight`` followed by the zeros.

    Raises:
        ValueError: If ``weight`` is not 1-D or 2-D, or if ``pad_dim`` is not
            0 or 1 for a 2-D tensor.
    """
    ndim = weight.dim()
    if ndim not in (1, 2):
        raise ValueError("Weight tensor must be 1D or 2D")

    if ndim == 1:
        cat_dim = 0
    elif pad_dim == 0:
        cat_dim = 0
    elif pad_dim == 1:
        cat_dim = 1
    else:
        raise ValueError("pad_dim must be 0 or 1")

    # The zero block matches ``weight`` everywhere except the padded dimension.
    pad_shape = list(weight.shape)
    pad_shape[cat_dim] = num_pad
    zeros = torch.zeros(pad_shape, dtype=weight.dtype, device=weight.device)
    return torch.cat([weight, zeros], dim=cat_dim).contiguous()


def gemm_bank_conf(weight: int) -> bool:
    """Return True when ``weight`` is a power of two that is also a multiple of 2048.

    For integers this is equivalent to ``weight`` being a power of two that is
    >= 2048; zero and negative values yield False. Despite its name,
    ``weight`` is an integer size, not a tensor.

    NOTE(review): the function name suggests it flags GEMM dimensions prone to
    memory-bank conflicts — confirm the intent with callers.

    Args:
        weight: Integer size to test.

    Returns:
        ``True`` if ``weight`` is a power of two divisible by 2048, else ``False``.
    """
    is_mul_of_2048 = weight % 2048 == 0
    # n & (n - 1) clears the lowest set bit, so the result is 0 only for 0 and
    # powers of two; the ``!= 0`` check then excludes 0.
    is_power_of_two = (weight & (weight - 1)) == 0 and weight != 0
    return bool(is_mul_of_2048 and is_power_of_two)
40
41
42
43


def set_weight_attrs(
    weight: torch.Tensor,
44
    weight_attrs: dict[str, Any] | None,
45
46
47
48
49
50
51
52
53
54
55
56
57
):
    """Set attributes on a weight tensor.

    This method is used to set attributes on a weight tensor. This method
    will not overwrite existing attributes.

    Args:
        weight: The weight tensor.
        weight_attrs: A dictionary of attributes to set on the weight tensor.
    """
    if weight_attrs is None:
        return
    for key, value in weight_attrs.items():
58
        assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}"
59
60
61
62
63
64
65
66
67
68

        # NOTE(woosuk): During weight loading, we often do something like:
        # narrowed_tensor = param.data.narrow(0, offset, len)
        # narrowed_tensor.copy_(real_weight)
        # expecting narrowed_tensor and param.data to share the same storage.
        # However, on TPUs, narrowed_tensor will lazily propagate to the base
        # tensor, which is param.data, leading to the redundant memory usage.
        # This sometimes causes OOM errors during model loading. To avoid this,
        # we sync the param tensor after its weight loader is called.
        # TODO(woosuk): Remove this hack once we have a better solution.
69
        from vllm.platforms import current_platform
70

71
        if current_platform.use_sync_weight_loader() and key == "weight_loader":
72
            value = current_platform.make_synced_weight_loader(value)
73
        setattr(weight, key, value)
74
75


76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
def replace_parameter(layer: torch.nn.Module, param_name: str, new_data: torch.Tensor):
    """
    Swap ``layer.<param_name>`` for a new, non-trainable parameter.

    Intended for use inside ``process_weights_after_loading``
    implementations. If the parameter being replaced carries a
    ``weight_loader`` attribute, it is copied onto the replacement so the
    weight can still be reloaded later.

    Do not call this on tied/shared weights: only the attribute on ``layer``
    is rebound, so other holders keep the old parameter.

    Args:
        layer: Layer containing parameter to replace
        param_name: Name of parameter to replace
        new_data: New data of the new parameter; if it is already a
            ``torch.nn.Parameter``, its underlying ``.data`` is used so the
            replacement gets a fresh wrapper
    """
    raw = new_data.data if isinstance(new_data, torch.nn.Parameter) else new_data
    replacement = torch.nn.Parameter(raw, requires_grad=False)

    previous: torch.nn.Parameter | None = getattr(layer, param_name, None)
    # Carry the reload hook over; the freshly built parameter has no attrs yet.
    if previous is not None and hasattr(previous, "weight_loader"):
        set_weight_attrs(replacement, {"weight_loader": previous.weight_loader})

    setattr(layer, param_name, replacement)


101
def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
    """Return the ``packed_modules_mapping`` for ``model``.

    If ``model`` defines a non-empty ``packed_modules_mapping`` attribute, a
    deep copy of it is returned as-is. Otherwise the mappings defined on the
    model's direct children (``model.children()``, not every nested
    submodule) are merged into a new dict.

    Args:
        model: Model whose mapping is requested.

    Returns:
        A fresh, deep-copied dict, so callers may mutate it without affecting
        the model's or children's attributes.

    Raises:
        ValueError: If two children define different values for the same key.
    """
    parent_map = getattr(model, "packed_modules_mapping", None)
    parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}

    # Don't infer a mapping if the model has defined one explicitly.
    if parent_map:
        return parent_map

    # We only check main components instead of whole model submodules.
    for child in model.children():
        child_map = getattr(child, "packed_modules_mapping", None)
        if child_map is None:
            continue
        child_map = copy.deepcopy(child_map)

        # parent_map accumulates earlier children, so this detects siblings
        # that disagree on the same key.
        if any(k in parent_map and parent_map[k] != v for k, v in child_map.items()):
            raise ValueError(
                f"Can't update {type(model).__name__}'s packed_modules_mapping "
                f"safely because of conflicts from {type(child).__name__}."
            )
        parent_map.update(child_map)

    return parent_map


def get_moe_expert_mapping(
    model: torch.nn.Module,
) -> list[tuple[str, str, int, str]]:
    """Return the MoE expert mapping exposed by ``model`` or one of its children.

    If ``model`` has a ``get_expert_mapping`` attribute, it is called and its
    result returned. Otherwise the first direct child (in
    ``model.children()`` order) that has ``get_expert_mapping`` is used.
    Submodules nested deeper than direct children are not searched.

    Args:
        model: Model to query.

    Returns:
        The result of the chosen ``get_expert_mapping()`` call, or ``[]`` when
        neither ``model`` nor any direct child provides one.
    """
    get_mapping = getattr(model, "get_expert_mapping", None)
    if get_mapping is not None:
        return get_mapping()

    # We only check main components instead of whole model submodules.
    for child in model.children():
        get_child_mapping = getattr(child, "get_expert_mapping", None)
        if get_child_mapping is not None:
            return get_child_mapping()
    return []
136
137
138
139
140
141
142


def maybe_disable_graph_partition(current_backend: str) -> dict[str, bool]:
    """Return compile options that turn off graph partitioning when required.

    Yields ``{"graph_partition": False}`` only when ``current_backend`` is
    ``"inductor"`` and the installed torch is 2.9.0.dev or newer; for every
    other backend/version combination an empty dict is returned. The torch
    version is only checked for the inductor backend.

    NOTE(review): the reason partitioning must be disabled on newer torch is
    not visible here — confirm with the callers that merge these options.
    """
    needs_disable = current_backend == "inductor" and is_torch_equal_or_newer("2.9.0.dev")
    return {"graph_partition": False} if needs_disable else {}