ema.py 2.12 KB
Newer Older
Jinhua Zhu's avatar
Jinhua Zhu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from copy import deepcopy
from itertools import chain
from unicore.optim.fp16_optimizer import pad_numel
import torch


class ExponentialMovingAverageModel:
    """Keep an exponential moving average (EMA) of a model's parameters.

    A float32 deep copy of the model is held in ``model_ema``. The copy's
    parameters are re-pointed at slices of one flat buffer (``param``), so
    updating the flat buffer in place also updates ``model_ema``.
    """

    def __init__(self, model, decay, init_param=None):
        """Copy ``model`` and flatten the copy's parameters.

        Args:
            model: Module to track. It is deep-copied; the original is only
                read (for parameter names and dtypes).
            decay: EMA decay factor. ``update`` moves the average a fraction
                ``1 - decay`` of the way toward the new values.
            init_param: Optional flat tensor that the freshly built buffer is
                expected to match; checked with ``torch.allclose``.
        """
        self.model_ema = deepcopy(model).float()
        self.decay = decay
        self.param = self.flatten_parameters(model, init_param)

    def flatten_parameters(self, model, init_param):
        """Pack the EMA copy's parameters into one flat float32 buffer.

        Parameters are ordered by the dtype they have in ``model`` (dtypes in
        order of first appearance, names in ``named_parameters`` order within
        a dtype). Each parameter of ``model_ema`` is copied into the buffer
        and then rebound to a view of its slice.

        Args:
            model: The original module; only names and dtypes are read.
            init_param: Optional tensor to compare the result against.

        Returns:
            ``torch.nn.Parameter`` wrapping the flat buffer. For a module
            without parameters this is an empty float32 parameter.

        Raises:
            AssertionError: If ``init_param`` is given and is not close to
                the flattened values.
        """
        # Group names by dtype; dict insertion order keeps the dtypes in
        # order of first appearance.
        names_by_dtype = {}
        for name, p in model.named_parameters():
            names_by_dtype.setdefault(p.dtype, []).append(name)
        ordered_names = list(chain.from_iterable(names_by_dtype.values()))

        # Look the same names up on the float32 copy.
        ema_params = dict(self.model_ema.named_parameters())
        cur_params = [ema_params[name] for name in ordered_names]

        # Every parameter occupies pad_numel(numel) slots in the buffer.
        # NOTE(review): presumably this mirrors the layout used by unicore's
        # fp16 optimizer so that `update` can consume its flat buffer; confirm.
        total_param_size = sum(pad_numel(p.numel()) for p in cur_params)

        # A parameter-less module has no tensor to borrow a device from, so
        # fall back to the default device instead of indexing an empty list.
        device = cur_params[0].device if cur_params else None
        flatten_param = torch.zeros(
            total_param_size, dtype=torch.float32, device=device
        )

        offset = 0
        for p in cur_params:
            numel = p.numel()
            chunk = flatten_param[offset : offset + numel]
            # reshape (not view) so non-contiguous parameters can be copied.
            chunk.copy_(p.data.reshape(-1))
            # Pass the shape as one argument: view(*shape) breaks for 0-dim
            # parameters because it degenerates to view() with no arguments.
            p.data = chunk.view(p.shape)
            offset += pad_numel(numel)

        flatten_param = torch.nn.Parameter(flatten_param)
        if init_param is not None:
            assert torch.allclose(init_param, flatten_param), "ema init error!"
        torch.cuda.empty_cache()
        return flatten_param

    def update(self, new_param):
        """Move the averaged parameters toward ``new_param`` in place.

        Computes ``param -= (1 - decay) * (param - new_param)`` without
        recording gradients.

        Args:
            new_param: Tensor of current values to blend in. Presumably a
                flat buffer laid out like ``self.param``; verify at caller.
        """
        with torch.no_grad():
            diff = self.param - new_param
            diff *= 1 - self.decay
            self.param -= diff

    def load_state_dict(self, state_dict):
        """Restore the EMA model and, if present, the decay factor.

        Args:
            state_dict: Mapping with a ``"params"`` entry (a module state
                dict for ``model_ema``) and an optional ``"decay"`` entry.
                When ``"decay"`` is absent the current decay is kept.
        """
        # NOTE(review): this relies on Module.load_state_dict copying values
        # into the existing parameter tensors, which keeps them as views of
        # the flat buffer; confirm against the torch version in use.
        self.model_ema.load_state_dict(state_dict["params"])
        self.decay = state_dict.get("decay", self.decay)

    def state_dict(self):
        """Return the EMA model's state dict together with the decay."""
        return {
            "params": self.model_ema.state_dict(),
            "decay": self.decay,
        }