"""Utils for model executor."""
from typing import Any, Dict, Optional

import torch

from vllm.platforms import current_platform
from vllm.utils import seed_everything


def set_random_seed(seed: int) -> None:
    """Set the random seed by delegating to ``seed_everything``.

    Args:
        seed: The seed value passed through unchanged to
            ``vllm.utils.seed_everything``.
    """
    seed_everything(seed)
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


def set_weight_attrs(
    weight: torch.Tensor,
    weight_attrs: Optional[Dict[str, Any]],
):
    """Set attributes on a weight tensor.

    This method is used to set attributes on a weight tensor. This method
    will not overwrite existing attributes.

    Args:
        weight: The weight tensor.
        weight_attrs: A dictionary of attributes to set on the weight tensor.
            ``None`` is accepted and means there is nothing to set.

    Raises:
        AssertionError: If ``weight`` already has an attribute with the same
            name as one of the keys. Attributes set for earlier keys are not
            rolled back.
    """
    if weight_attrs is None:
        return
    for key, value in weight_attrs.items():
        assert not hasattr(
            weight, key), (f"Overwriting existing tensor attribute: {key}")

        # NOTE(woosuk): During weight loading, we often do something like:
        # narrowed_tensor = param.data.narrow(0, offset, len)
        # narrowed_tensor.copy_(real_weight)
        # expecting narrowed_tensor and param.data to share the same storage.
        # However, on TPUs, narrowed_tensor will lazily propagate to the base
        # tensor, which is param.data, leading to the redundant memory usage.
        # This sometimes causes OOM errors during model loading. To avoid this,
        # we sync the param tensor after its weight loader is called.
        # TODO(woosuk): Remove this hack once we have a better solution.
        if current_platform.is_tpu() and key == "weight_loader":
            value = _make_synced_weight_loader(value)
        setattr(weight, key, value)
45
46


47
48
49
50
51
52
53
def _make_synced_weight_loader(original_weight_loader):

    def _synced_weight_loader(param, *args, **kwargs):
        original_weight_loader(param, *args, **kwargs)
        torch._sync(param)

    return _synced_weight_loader
54
55


56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def pad_weight(weight: torch.Tensor, num_pad: int, pad_dim: int = 0):
    """Append ``num_pad`` zeros to a 1D or 2D tensor along one dimension.

    The zeros are created with the dtype and device of ``weight`` and are
    concatenated after the existing values.

    Args:
        weight: The 1D or 2D tensor to extend.
        num_pad: Number of zero entries (1D) or zero rows/columns (2D) to
            append.
        pad_dim: Dimension of a 2D tensor to extend, ``0`` or ``1``. It is
            ignored for 1D tensors, which are always extended along dim 0.

    Returns:
        A new contiguous tensor holding ``weight`` followed by the zeros.

    Raises:
        ValueError: If ``weight`` is not 1D or 2D, or if ``weight`` is 2D and
            ``pad_dim`` is neither 0 nor 1.
    """
    rank = weight.dim()
    if rank not in (1, 2):
        raise ValueError("Weight tensor must be 1D or 2D")

    if rank == 1:
        # A vector has a single axis, so pad_dim is not consulted.
        axis = 0
    elif pad_dim == 0:
        axis = 0
    elif pad_dim == 1:
        axis = 1
    else:
        raise ValueError("pad_dim must be 0 or 1")

    # The zero block matches ``weight`` in every dimension except the one
    # being extended, where it has ``num_pad`` entries.
    zeros_shape = list(weight.shape)
    zeros_shape[axis] = num_pad
    zeros = torch.zeros(zeros_shape, dtype=weight.dtype, device=weight.device)
    return torch.cat([weight, zeros], dim=axis).contiguous()


def gemm_bank_conf(weight):
    """Return whether ``weight`` is a multiple of 2048 and a power of two.

    NOTE(review): presumably ``weight`` is an integer size (the bitwise test
    below needs an integer-like value) - confirm against callers.

    Args:
        weight: The value to test.

    Returns:
        ``True`` if ``weight`` is non-zero, has exactly one bit set and is
        divisible by 2048; ``False`` otherwise.
    """
    is_mul_of_2048 = weight % 2048 == 0
    # ``x & (x - 1)`` clears the lowest set bit, so it is zero only when at
    # most one bit is set; zero itself is excluded explicitly.
    is_power_of_two = (weight & (weight - 1)) == 0 and weight != 0
    return bool(is_mul_of_2048 and is_power_of_two)