utils.py 4.12 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Utils for model executor."""
4
import copy
5
from typing import Any, Optional
6
7
8
9
10

import torch


def set_random_seed(seed: int) -> None:
    """Forward ``seed`` to the active platform's ``seed_everything`` hook.

    The platform module is imported at call time rather than at module
    import time.

    Args:
        seed: The seed value, passed through unchanged.
    """
    from vllm.platforms import current_platform as platform
    platform.seed_everything(seed)
13
14
15
16


def set_weight_attrs(
    weight: torch.Tensor,
    weight_attrs: Optional[dict[str, Any]],
) -> None:
    """Set attributes on a weight tensor.

    This method is used to set attributes on a weight tensor. It never
    overwrites existing attributes: every key is validated before any
    attribute is set, so a conflict leaves ``weight`` unchanged.

    When running on TPU, a ``weight_loader`` value is wrapped with
    ``_make_synced_weight_loader`` before it is attached.

    Args:
        weight: The weight tensor.
        weight_attrs: A dictionary of attributes to set on the weight tensor.
            ``None`` means there is nothing to set.

    Raises:
        AssertionError: If ``weight`` already has an attribute whose name is
            a key of ``weight_attrs``.
    """
    if weight_attrs is None:
        return

    # Validate every key before mutating anything, so a conflict on a later
    # key cannot leave the tensor partially updated. An explicit raise (not
    # ``assert``) keeps the check active under ``python -O``; AssertionError
    # is kept so callers still see the same exception type.
    for key in weight_attrs:
        if hasattr(weight, key):
            raise AssertionError(
                f"Overwriting existing tensor attribute: {key}")

    if not weight_attrs:
        return

    # NOTE(woosuk): During weight loading, we often do something like:
    # narrowed_tensor = param.data.narrow(0, offset, len)
    # narrowed_tensor.copy_(real_weight)
    # expecting narrowed_tensor and param.data to share the same storage.
    # However, on TPUs, narrowed_tensor will lazily propagate to the base
    # tensor, which is param.data, leading to the redundant memory usage.
    # This sometimes causes OOM errors during model loading. To avoid this,
    # we sync the param tensor after its weight loader is called.
    # TODO(woosuk): Remove this hack once we have a better solution.
    from vllm.platforms import current_platform

    for key, value in weight_attrs.items():
        if key == "weight_loader" and current_platform.is_tpu():
            value = _make_synced_weight_loader(value)
        setattr(weight, key, value)
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74


def pad_weight(weight: torch.Tensor, num_pad: int, pad_dim: int = 0):
    """Return ``weight`` with ``num_pad`` zero-filled slices appended.

    The padding has the same dtype and device as ``weight``, and the result
    is contiguous.

    Args:
        weight: A 1-D or 2-D tensor.
        num_pad: Number of zero slices to append.
        pad_dim: For 2-D tensors, 0 appends rows and 1 appends columns.
            It is ignored for 1-D tensors, which are always padded along
            their only dimension.

    Returns:
        A new tensor with ``num_pad`` extra entries along the padded axis.

    Raises:
        ValueError: If ``weight`` is not 1-D or 2-D, or if ``pad_dim`` is
            not 0 or 1 for a 2-D tensor.
    """
    rank = weight.dim()
    if rank not in (1, 2):
        raise ValueError("Weight tensor must be 1D or 2D")

    # A 1-D tensor only has axis 0, so pad_dim is deliberately ignored.
    axis = 0 if rank == 1 else pad_dim
    if axis not in (0, 1):
        raise ValueError("pad_dim must be 0 or 1")

    pad_shape = list(weight.shape)
    pad_shape[axis] = num_pad
    zeros = torch.zeros(pad_shape, dtype=weight.dtype, device=weight.device)
    return torch.cat([weight, zeros], dim=axis).contiguous()


def gemm_bank_conf(weight):
    """Check whether ``weight`` is a power of two that is divisible by 2048.

    The name suggests this flags GEMM sizes prone to memory-bank conflicts
    (inferred from the name only; TODO confirm against callers).

    Args:
        weight: Despite the name, the bitwise ``&`` below means this is
            expected to be an integer (e.g. a dimension size), not a tensor.

    Returns:
        ``True`` when ``weight`` is non-zero, a power of two, and a multiple
        of 2048 (equivalently, a power of two >= 2048); otherwise ``False``.
    """
    divisible_by_2048 = weight % 2048 == 0
    # n & (n - 1) clears the lowest set bit, so it is 0 only for positive
    # powers of two and for 0 itself; 0 is excluded explicitly.
    single_bit_set = (weight & (weight - 1)) == 0 and weight != 0
    return bool(divisible_by_2048 and single_bit_set)

77
78
79
80
def _make_synced_weight_loader(original_weight_loader):

    def _synced_weight_loader(param, *args, **kwargs):
        original_weight_loader(param, *args, **kwargs)
81
82
83
        # torch._sync doesn't support, is not needed for CPU tensors.
        if param.device != torch.device("cpu"):
            torch._sync(param)
84
85

    return _synced_weight_loader
86
87
88


def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
    """Return a deep-copied ``packed_modules_mapping`` for ``model``.

    If ``model`` defines a non-empty ``packed_modules_mapping``, a deep copy
    of it is returned unchanged. Otherwise, the mappings of the model's
    direct children (``model.children()``, not all submodules) are deep
    copied and merged.

    Args:
        model: The module to inspect.

    Returns:
        The merged mapping; ``{}`` if neither the model nor any direct child
        defines one.

    Raises:
        ValueError: If two children map the same key to different values.
    """
    explicit = getattr(model, "packed_modules_mapping", None)
    merged = {} if explicit is None else copy.deepcopy(explicit)

    # An explicitly defined mapping wins; don't infer one from children.
    if merged:
        return merged

    for submodule in model.children():
        raw = getattr(submodule, "packed_modules_mapping", None)
        incoming = {} if raw is None else copy.deepcopy(raw)

        clash = any(name in merged and merged[name] != packed
                    for name, packed in incoming.items())
        if clash:
            raise ValueError(
                f"Can't update {type(model).__name__}'s packed_modules_mapping "
                f"safely because of conflicts from {type(submodule).__name__}.")
        merged.update(incoming)

    return merged