# encoding: utf-8
"""
@author:  liaoxingyu
@contact: sherlockliao01@gmail.com
"""

# based on: https://github.com/PhilJd/contiguous_pytorch_params/blob/master/contiguous_params/params.py
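# ContiguousParams copies each parameter into one flat buffer per
# (freeze_status, lr, weight_decay) group and re-binds the original
# parameters (and their gradients) as views of that buffer, so the optimizer
# can update a whole group with a few large kernels instead of one tiny
# launch per tensor.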

from collections import OrderedDict

import torch


class ContiguousParams:

    def __init__(self, parameters):
        # Materialize the parameters as a list so repeated iteration
        # (dtype/device checks, grouping) does not exhaust a generator.
        self._parameters = list(parameters)
        self._param_buffer = []
        self._grad_buffer = []
        self._group_dict = OrderedDict()
        self._name_buffer = []
        self._init_buffers()
        # Store the data pointer of each parameter's buffer view. These are
        # used to check whether an operation replaced the gradient/data
        # tensor (invalidating the assumption of a contiguous buffer).
        self.data_pointers = []
        self.grad_pointers = []
        self.make_params_contiguous()

    def _init_buffers(self):
        # Each entry of self._parameters is a param group dict holding a
        # single parameter under "params"; all parameters must share one
        # dtype and device so they can live in a single flat buffer.
        dtype = self._parameters[0]["params"][0].dtype
        device = self._parameters[0]["params"][0].device
        if not all(p["params"][0].dtype == dtype for p in self._parameters):
            raise ValueError("All parameters must be of the same dtype.")
        if not all(p["params"][0].device == device for p in self._parameters):
            raise ValueError("All parameters must be on the same device.")

        # Group parameters by freeze status, lr, and weight decay. The key is
        # parsed back with str.split('_') in contiguous(), so freeze_status
        # must not itself contain an underscore.
        for param_dict in self._parameters:
            freeze_status = param_dict["freeze_status"]
            param_key = freeze_status + '_' + str(param_dict["lr"]) + '_' + str(param_dict["weight_decay"])
            self._group_dict.setdefault(param_key, []).append(param_dict)

        # Allocate one flat parameter buffer and one flat gradient buffer per group.
        for key, params in self._group_dict.items():
            size = sum(p["params"][0].numel() for p in params)
            self._param_buffer.append(torch.zeros(size, dtype=dtype, device=device))
            self._grad_buffer.append(torch.zeros(size, dtype=dtype, device=device))
            self._name_buffer.append(key)

    def make_params_contiguous(self):
        """Copy all params into the contiguous buffers and re-bind each
        parameter and its gradient as views of its group's buffer.
        """
        for i, params in enumerate(self._group_dict.values()):
            index = 0
            for param_dict in params:
                p = param_dict["params"][0]
                size = p.numel()
                self._param_buffer[i][index:index + size] = p.data.view(-1)
                p.data = self._param_buffer[i][index:index + size].view(p.data.shape)
                p.grad = self._grad_buffer[i][index:index + size].view(p.data.shape)
                # Record the raw pointers; buffer_is_valid() compares against
                # them to detect out-of-place reassignment of data or grad.
                self.data_pointers.append(p.data.data_ptr())
                self.grad_pointers.append(p.grad.data_ptr())
                index += size
            # Bend the param_buffer to use grad_buffer to track its gradients.
            self._param_buffer[i].grad = self._grad_buffer[i]

    def contiguous(self):
        """Return one param group per buffer, each backed by a single
        contiguous tensor, suitable for passing to a torch optimizer."""
        groups = []
        for i, name in enumerate(self._name_buffer):
            freeze_status, lr, weight_decay = name.split('_')
            groups.append({
                "freeze_status": freeze_status,
                "params": self._param_buffer[i],
                "lr": float(lr),
                "weight_decay": float(weight_decay),
            })
        return groups

    def original(self):
        """Return the non-flattened parameters."""
        return self._parameters

    def buffer_is_valid(self):
        """Verify that all parameters and gradients still use the buffers."""
        i = 0
        for params in self._group_dict.values():
            for param_dict in params:
                p = param_dict["params"][0]
                if (p.data.data_ptr() != self.data_pointers[i]) or \
                        (p.grad.data_ptr() != self.grad_pointers[i]):
                    return False
                i += 1
        return True

    def assert_buffer_is_valid(self):
        if not self.buffer_is_valid():
            raise ValueError(
                "The data or gradient buffer has been invalidated. Please make "
                "sure to use inplace operations only when updating parameters "
                "or gradients.")