cocktail_optimizer.py

from torch.optim import Optimizer

from mmdet.core.optimizer import OPTIMIZERS


@OPTIMIZERS.register_module
class CocktailOptimizer(Optimizer):
    """Cocktail Optimizer that contains multiple optimizers

    This optimizer applies the cocktail optimization for multi-modality
    models: each wrapped optimizer is stepped at its own interval.

    Args:
        optimizers (list): The optimizers to be wrapped.
        step_intervals (list[int], optional): Update interval of each
            optimizer, counted in calls to ``step()``. Defaults to 1 for
            every optimizer.
    """

    def __init__(self, optimizers, step_intervals=None):
        self.optimizers = optimizers
        # Expose the param groups of every inner optimizer through a single
        # flat list, as consumers of a regular optimizer expect.
        self.param_groups = []
        for optimizer in self.optimizers:
            self.param_groups += optimizer.param_groups
        # By default, every inner optimizer is updated on every step.
        if not isinstance(step_intervals, list):
            step_intervals = [1] * len(self.optimizers)
        self.step_intervals = step_intervals
        self.num_step_updated = 0

    def __getstate__(self):
        return {
            'num_step_updated':
            self.num_step_updated,
            'defaults': [optimizer.defaults for optimizer in self.optimizers],
            'state': [optimizer.state for optimizer in self.optimizers],
            'param_groups':
            [optimizer.param_groups for optimizer in self.optimizers],
        }

    def __setstate__(self, state):
        self.__dict__.update(state)

    def __repr__(self):
        format_string = self.__class__.__name__ + ' (\n'
        for optimizer, interval in zip(self.optimizers, self.step_intervals):
            format_string += 'Update interval: {}\n'.format(interval)
            format_string += optimizer.__repr__().replace('\n', '\n  ') + ',\n'
        format_string += ')'
        return format_string

    def state_dict(self):
        state_dicts = [optimizer.state_dict() for optimizer in self.optimizers]
        return {
            'num_step_updated':
            self.num_step_updated,
            'state': [state_dict['state'] for state_dict in state_dicts],
            'param_groups':
            [state_dict['param_groups'] for state_dict in state_dicts],
        }

    def load_state_dict(self, state_dict):
        r"""Loads the optimizer state.

        Arguments:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        assert len(state_dict['state']) == len(self.optimizers)
        assert len(state_dict['param_groups']) == len(self.optimizers)
        for i, (single_state, single_param_groups) in enumerate(
                zip(state_dict['state'], state_dict['param_groups'])):
            single_state_dict = dict(
                state=single_state, param_groups=single_param_groups)
            self.optimizers[i].load_state_dict(single_state_dict)

        self.param_groups = []
        for optimizer in self.optimizers:
            self.param_groups += optimizer.param_groups
        self.num_step_updated = state_dict['num_step_updated']

    def zero_grad(self):
        r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
        for optimizer in self.optimizers:
            optimizer.zero_grad()

    def step(self, closure=None):
        r"""Performs a single optimization step (parameter update).

        Arguments:
            closure (callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.
        """
        loss = None
        if closure is not None:
            loss = closure()

        self.num_step_updated += 1
        # Only step the optimizers whose update interval divides the current
        # number of updates.
        for step_interval, optimizer in zip(self.step_intervals,
                                            self.optimizers):
            if self.num_step_updated % step_interval == 0:
                optimizer.step()

        return loss

    def add_param_group(self, param_group):
        raise NotImplementedError
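

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module; it assumes mmdet is
# importable so the registry decoration at the top of this file succeeds).
# It uses a toy multi-modality setup with a hypothetical point-cloud branch
# and image branch, each driven by its own inner optimizer; the second
# optimizer is only stepped every other iteration via ``step_intervals``.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch
    import torch.nn as nn

    # Hypothetical per-modality branches standing in for a real model.
    pts_branch = nn.Linear(32, 4)
    img_branch = nn.Linear(16, 4)

    optimizer = CocktailOptimizer(
        optimizers=[
            torch.optim.SGD(pts_branch.parameters(), lr=0.01, momentum=0.9),
            torch.optim.Adam(img_branch.parameters(), lr=1e-4),
        ],
        step_intervals=[1, 2])

    for _ in range(4):
        optimizer.zero_grad()
        loss = (pts_branch(torch.randn(2, 32)).sum() +
                img_branch(torch.randn(2, 16)).sum())
        loss.backward()
        # The SGD optimizer updates on every call, the Adam one every 2nd call.
        optimizer.step()

    # The combined state can be checkpointed and restored like a regular
    # optimizer: ``state_dict``/``load_state_dict`` delegate to the inner
    # optimizers.
    ckpt = optimizer.state_dict()
    optimizer.load_state_dict(ckpt)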