# weight_module.py
class WeightModule:
    """Registry of named sub-modules and parameters that fans calls out to them.

    Children are stored in two insertion-ordered dicts (``_modules`` and
    ``_parameters``) and are additionally exposed as instance attributes under
    their registered name. Most methods simply forward the same call to every
    child that implements it.
    """

    def __init__(self):
        self._modules = {}
        self._parameters = {}

    def is_empty(self):
        """Return True when neither a module nor a parameter is registered."""
        return not self._modules and not self._parameters

    def add_module(self, name, module):
        """Register ``module`` under ``name`` and expose it as ``self.<name>``."""
        self._modules[name] = module
        setattr(self, name, module)

    def register_parameter(self, name, param):
        """Register ``param`` under ``name`` and expose it as ``self.<name>``."""
        self._parameters[name] = param
        setattr(self, name, param)

    def _call_on_children(self, method_name, *args):
        """Call ``method_name(*args)`` on each module, then each parameter, that defines it."""
        for registry in (self._modules, self._parameters):
            for child in registry.values():
                if hasattr(child, method_name):
                    getattr(child, method_name)(*args)

    def load(self, weight_dict):
        """Forward ``load(weight_dict)`` to every child that has a ``load`` method."""
        self._call_on_children("load", weight_dict)

    def calculate_size(self):
        """Return the sum of ``_calculate_size()`` over the direct children that define it.

        NOTE(review): only ``_calculate_size`` is looked up, so a nested
        ``WeightModule`` (which defines ``calculate_size`` only) contributes
        nothing here. Confirm that this is intended.
        """
        total_size = 0
        for registry in (self._modules, self._parameters):
            for child in registry.values():
                if hasattr(child, "_calculate_size"):
                    total_size += child._calculate_size()
        return total_size

    def load_from_disk(self):
        """Forward ``load_from_disk()`` to every child that implements it."""
        self._call_on_children("load_from_disk")

    def clear(self):
        """Forward ``clear()`` to every child that implements it."""
        self._call_on_children("clear")

    def state_dict(self, destination=None):
        """Let every non-None parameter, then module, write itself into ``destination``.

        Args:
            destination: dict handed to each child's ``state_dict``; a new
                dict is created when omitted.

        Returns:
            The ``destination`` dict.
        """
        if destination is None:
            destination = {}
        for registry in (self._parameters, self._modules):
            for child in registry.values():
                if child is not None:
                    child.state_dict(destination)
        return destination

    def load_state_dict(self, destination, block_index, adapter_block_index=None):
        """Forward ``load_state_dict`` to every non-None parameter, then module.

        Args:
            destination: dict passed through to each child; replaced by a new
                dict when None.
            block_index: passed through unchanged to each child.
            adapter_block_index: passed through unchanged to each child.

        Returns:
            The ``destination`` dict.
        """
        if destination is None:
            destination = {}
        for registry in (self._parameters, self._modules):
            for child in registry.values():
                if child is not None:
                    child.load_state_dict(destination, block_index, adapter_block_index)
        return destination

    def named_parameters(self, prefix=""):
        """Yield ``(dotted_name, parameter)`` pairs, own parameters first, then recursively.

        Args:
            prefix: string prepended to every yielded name.
        """
        for name, param in self._parameters.items():
            if param is not None:
                yield prefix + name, param
        for name, module in self._modules.items():
            if module is not None:
                yield from module.named_parameters(prefix + name + ".")

    def _transfer_weights(self, device, non_blocking):
        """Move every parameter and module to ``device`` ("cpu" or "cuda").

        Parameters exposing a ``cpu``/``cuda`` method are replaced by that
        method's return value; parameters exposing ``to_cpu``/``to_cuda`` are
        moved by calling it. Modules are moved through their
        ``to_cpu``/``to_cuda`` method.

        ``non_blocking=True`` is only forwarded as a keyword when requested, so
        children whose methods take no such keyword still work on the
        synchronous path.
        """
        tensor_method = device
        weight_method = "to_" + device
        kwargs = {"non_blocking": True} if non_blocking else {}

        for name, param in self._parameters.items():
            if param is None:
                continue
            if hasattr(param, tensor_method):
                if non_blocking and device == "cpu" and hasattr(param, "to"):
                    # torch.Tensor.cpu() accepts no ``non_blocking`` argument,
                    # so asynchronous host copies go through ``.to``.
                    moved = param.to("cpu", non_blocking=True)
                else:
                    moved = getattr(param, tensor_method)(**kwargs)
                self._parameters[name] = moved
            elif hasattr(param, weight_method):
                getattr(param, weight_method)(**kwargs)
            # Keep the instance attribute pointing at the registered object.
            setattr(self, name, self._parameters[name])

        for module in self._modules.values():
            if isinstance(module, WeightModuleList):
                # List entries are not moved through their own to_cpu/to_cuda;
                # their direct children are moved instead.
                # NOTE(review): parameters of list entries that only expose
                # ``cpu``/``cuda`` (no ``to_cpu``/``to_cuda``) are skipped here,
                # as in the original code. Confirm that this is intended.
                for item in module:
                    if item is None:
                        continue
                    for registry in (item._modules, item._parameters):
                        for child in registry.values():
                            if child is not None and hasattr(child, weight_method):
                                getattr(child, weight_method)(**kwargs)
            elif module is not None and hasattr(module, weight_method):
                getattr(module, weight_method)(**kwargs)

    def to_cpu(self, non_blocking=False):
        """Move all registered parameters and modules to the CPU.

        Args:
            non_blocking: when true, children are asked for a non-blocking
                transfer. Accepted so that a parent's asynchronous transfer
                can recurse into nested ``WeightModule`` children.
        """
        self._transfer_weights("cpu", non_blocking)

    def to_cuda(self, non_blocking=False):
        """Move all registered parameters and modules to CUDA.

        Args:
            non_blocking: when true, children are asked for a non-blocking
                transfer.
        """
        self._transfer_weights("cuda", non_blocking)

    def to_cpu_async(self):
        """Move everything to the CPU, requesting non-blocking transfers."""
        self._transfer_weights("cpu", True)

    def to_cuda_async(self):
        """Move everything to CUDA, requesting non-blocking transfers."""
        self._transfer_weights("cuda", True)
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187


class WeightModuleList(WeightModule):
    """Ordered list of modules, each also registered under its index as a string.

    Every entry is stored in ``_list`` and additionally registered through
    ``add_module`` under ``str(position)``, so the inherited fan-out methods
    reach the entries as well.
    """

    def __init__(self, modules=None):
        """Create the list, appending each item of ``modules`` when given."""
        super().__init__()
        self._list = []
        if modules is not None:
            for module in modules:
                self.append(module)

    def append(self, module):
        """Add ``module`` at the end and register it under its position."""
        idx = len(self._list)
        self._list.append(module)
        self.add_module(str(idx), module)

    def __getitem__(self, idx):
        return self._list[idx]

    def __setitem__(self, idx, module):
        """Replace the entry at ``idx`` and re-register it under that position.

        Negative integer indices are normalized first. Otherwise ``lst[-1] = m``
        would register ``m`` under the name "-1" and leave the replaced module
        registered under its old positional name.
        """
        self._list[idx] = module  # raises IndexError for an out-of-range index
        if isinstance(idx, int) and idx < 0:
            idx += len(self._list)
        self.add_module(str(idx), module)

    def __len__(self):
        return len(self._list)

    def __iter__(self):
        return iter(self._list)