from torch.optim import Optimizer

from .registry import OPTIMIZERS


@OPTIMIZERS.register_module
class MixedOptimizer(Optimizer):
    """Mixed Optimizer that contains multiple optimizers

    This optimizer applies the cocktail optimzation for multi-modality models.

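    Example:
        An illustrative sketch; the SGD/Adam choices, hyper-parameters
        and toy modules below are assumptions, not requirements of this
        class.

        >>> import torch
        >>> img_branch = torch.nn.Linear(16, 4)
        >>> pts_branch = torch.nn.Linear(32, 4)
        >>> optimizer = MixedOptimizer(
        ...     optimizers=[
        ...         torch.optim.SGD(img_branch.parameters(), lr=0.01),
        ...         torch.optim.Adam(pts_branch.parameters(), lr=1e-3)
        ...     ],
        ...     step_intervals=[1, 2])  # Adam is stepped every 2nd call
        >>> optimizer.zero_grad()
        >>> loss = (img_branch(torch.randn(2, 16)).sum() +
        ...         pts_branch(torch.randn(2, 32)).sum())
        >>> loss.backward()
        >>> optimizer.step()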
    """

    def __init__(self, optimizers, step_intervals=None):
        self.optimizers = optimizers
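        # Expose a flattened view of all param groups so that hooks and
        # LR schedulers can treat this wrapper as a single optimizer.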
        self.param_groups = []
        for optimizer in self.optimizers:
            self.param_groups += optimizer.param_groups
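        # Without explicit intervals, step every optimizer on every call.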
        if not isinstance(step_intervals, list):
            step_intervals = [1] * len(self.optimizers)
        self.step_intervals = step_intervals
        self.num_step_updated = 0

    def __getstate__(self):
        return {
            'num_step_updated':
            self.num_step_updated,
            'defaults': [optimizer.defaults for optimizer in self.optimizers],
            'state': [optimizer.state for optimizer in self.optimizers],
            'param_groups':
            [optimizer.param_groups for optimizer in self.optimizers],
        }

    def __setstate__(self, state):
        self.__dict__.update(state)

    def __repr__(self):
        format_string = self.__class__.__name__ + ' (\n'
        for optimizer in self.optimizers:
            format_string += '\t' + repr(optimizer) + ',\n'
        format_string += ')'
        return format_string

    def state_dict(self):
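        # Collect the children's state dicts so that their ``state`` and
        # ``param_groups`` entries can be restored positionally in
        # :meth:`load_state_dict`.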
        state_dicts = [optimizer.state_dict() for optimizer in self.optimizers]
        return {
            'num_step_updated':
            self.num_step_updated,
            'state': [state_dict['state'] for state_dict in state_dicts],
            'param_groups':
            [state_dict['param_groups'] for state_dict in state_dicts],
        }

    def load_state_dict(self, state_dict):
        r"""Loads the optimizer state.

        Arguments:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        assert len(state_dict['state']) == len(self.optimizers)
        assert len(state_dict['param_groups']) == len(self.optimizers)
        for i, (single_state, single_param_groups) in enumerate(
                zip(state_dict['state'], state_dict['param_groups'])):
            single_state_dict = dict(
                state=single_state, param_groups=single_param_groups)
            self.optimizers[i].load_state_dict(single_state_dict)

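        # Rebuild the flattened param_groups view from the freshly loaded
        # optimizers.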
        self.param_groups = []
        for optimizer in self.optimizers:
            self.param_groups += optimizer.param_groups
        self.num_step_updated = state_dict['num_step_updated']

    def zero_grad(self):
        r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
        for optimizer in self.optimizers:
            optimizer.zero_grad()

    def step(self, closure=None):
        r"""Performs a single optimization step (parameter update).

        Arguments:
            closure (callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.
        """
        loss = None
        if closure is not None:
            loss = closure()

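        # Each optimizer is only updated on every ``step_interval``-th
        # call, so the constituent optimizers can run at different rates.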
        self.num_step_updated += 1
        for step_interval, optimizer in zip(self.step_intervals,
                                            self.optimizers):
            if self.num_step_updated % step_interval == 0:
                optimizer.step()

        return loss

    def add_param_group(self, param_group):
        """Not supported: add new param groups to the wrapped optimizers
        directly instead."""
        raise NotImplementedError