builder.py

import re

import torch

from mmdet.utils import build_from_cfg, get_root_logger
from .registry import OPTIMIZERS


def build_optimizer(model, optimizer_cfg):
    """Build optimizer from configs.

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        optimizer_cfg (dict | list[dict]): The config dict of the optimizer.
            Required fields are:
                - type: class name of the optimizer.
                - lr: base learning rate.
            Optional fields are:
                - any arguments of the corresponding optimizer type, e.g.,
                  weight_decay, momentum, etc.
                - paramwise_options: a dict with 4 accepted fields
                  (bias_lr_mult, bias_decay_mult, norm_decay_mult,
                  dwconv_decay_mult).
                  `bias_lr_mult` and `bias_decay_mult` are multiplied with
                  the lr and weight decay, respectively, for all bias
                  parameters (except those of normalization layers), and
                  `norm_decay_mult` is multiplied with the weight decay of
                  all weight and bias parameters of normalization layers.
                  `dwconv_decay_mult` is multiplied with the weight decay of
                  all weight and bias parameters of depthwise conv layers.
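            Alternatively, `optimizer_cfg` may be a list of such dicts, each
            with an extra `key` field (a substring matched against parameter
            names) and an optional `step_interval`; in that case a
            `MixedOptimizer` with one sub-optimizer per key is returned.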

    Returns:
        torch.optim.Optimizer: The initialized optimizer.

    Example:
        >>> import torch
        >>> model = torch.nn.Conv1d(1, 1, 1)
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
        ...                      weight_decay=0.0001)
        >>> optimizer = build_optimizer(model, optimizer_cfg)
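        >>> # A sketch with per-parameter options (the mult values are
        >>> # illustrative, not recommended defaults):
        >>> optimizer_cfg = dict(
        ...     type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001,
        ...     paramwise_options=dict(bias_lr_mult=2., norm_decay_mult=0.))
        >>> optimizer = build_optimizer(model, optimizer_cfg)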
    """
    if hasattr(model, 'module'):
        model = model.module

    if isinstance(optimizer_cfg, list):
        # A list config builds a MixedOptimizer; paramwise_options is
        # assumed to be None in this branch. Copy each config dict so the
        # pop() calls below do not mutate the caller's configs (a plain
        # list.copy() would only be a shallow copy).
        optimizer_cfg = [single_cfg.copy() for single_cfg in optimizer_cfg]
        from .mix_optimizer import MixedOptimizer
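        # A sketch of the expected list-style config (keys and
        # hyperparameters are illustrative; together, the key substrings
        # must cover every parameter name in the model):
        #
        #   optimizer_cfg = [
        #       dict(key='backbone', type='SGD', lr=0.001, momentum=0.9),
        #       dict(key='head', type='Adam', lr=0.01, step_interval=2),
        #   ]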
        logger = get_root_logger()
        keys = [optimizer.pop('key') for optimizer in optimizer_cfg]
        keys_params = {key: [] for key in keys}
        keys_params_name = {key: [] for key in keys}
        keys_optimizer = []
        for name, param in model.named_parameters():
            param_group = {'params': [param]}
            find_flag = False
            for key in keys:
                if key in name:
                    keys_params[key].append(param_group)
                    keys_params_name[key].append(name)
                    find_flag = True
                    break
            assert find_flag, \
                'parameter {} is not matched by any optimizer key'.format(
                    name)

        step_intervals = []
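        # Build one torch.optim optimizer per key. 'step_interval' is popped
        # before instantiation and forwarded to MixedOptimizer (presumably
        # how often the corresponding sub-optimizer steps; see .mix_optimizer
        # for the exact semantics).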
        for key, single_cfg in zip(keys, optimizer_cfg):
            optimizer_cls = getattr(torch.optim, single_cfg.pop('type'))
            step_intervals.append(single_cfg.pop('step_interval', 1))
            single_optim = optimizer_cls(keys_params[key], **single_cfg)
            keys_optimizer.append(single_optim)
            logger.info('{} optimizes key:\n {}\n'.format(
                optimizer_cls.__name__, keys_params_name[key]))

        mix_optimizer = MixedOptimizer(keys_optimizer, step_intervals)
        return mix_optimizer
    else:
        optimizer_cfg = optimizer_cfg.copy()
        paramwise_options = optimizer_cfg.pop('paramwise_options', None)

    # if no paramwise option is specified, just use the global setting
    if paramwise_options is None:
        params = model.parameters()
    else:
        assert isinstance(paramwise_options, dict)
        # get base lr and weight decay
        base_lr = optimizer_cfg['lr']
        base_wd = optimizer_cfg.get('weight_decay', None)
        # weight_decay must be explicitly specified if mult is specified
        if ('bias_decay_mult' in paramwise_options
                or 'norm_decay_mult' in paramwise_options
                or 'dwconv_decay_mult' in paramwise_options):
            assert base_wd is not None
        # get param-wise options
        bias_lr_mult = paramwise_options.get('bias_lr_mult', 1.)
        bias_decay_mult = paramwise_options.get('bias_decay_mult', 1.)
        norm_decay_mult = paramwise_options.get('norm_decay_mult', 1.)
        dwconv_decay_mult = paramwise_options.get('dwconv_decay_mult', 1.)
        named_modules = dict(model.named_modules())
        # set param-wise lr and weight decay
        params = []
        for name, param in model.named_parameters():
            param_group = {'params': [param]}
            if not param.requires_grad:
                # FP16 training needs to copy gradients/weights between the
                # master weight copy and the model weights, so it is
                # convenient to keep all parameters here to align with
                # model.parameters()
                params.append(param_group)
                continue

            # for norm layers, overwrite the weight decay of weight and bias
            # TODO: obtain the norm layer prefixes dynamically
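            # e.g. this matches names like 'layer1.1.bn2.weight' or
            # 'gn.bias' (these parameter names are illustrative)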
            if re.search(r'(bn|gn)(\d+)?\.(weight|bias)', name):
                if base_wd is not None:
                    param_group['weight_decay'] = base_wd * norm_decay_mult
            # for other layers, overwrite both lr and weight decay of bias
            elif name.endswith('.bias'):
                param_group['lr'] = base_lr * bias_lr_mult
                if base_wd is not None:
                    param_group['weight_decay'] = base_wd * bias_decay_mult

            module_name = name.replace('.weight', '').replace('.bias', '')
            if module_name in named_modules and base_wd is not None:
                module = named_modules[module_name]
                # if this Conv2d is depthwise Conv2d
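                # (e.g. torch.nn.Conv2d(32, 32, 3, groups=32) is depthwise:
                # in_channels == groups)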
                if (isinstance(module, torch.nn.Conv2d)
                        and module.in_channels == module.groups):
                    param_group['weight_decay'] = base_wd * dwconv_decay_mult
            # otherwise use the global settings

            params.append(param_group)

    optimizer_cfg['params'] = params

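    # The remaining fields of optimizer_cfg (type, lr, ...) plus 'params'
    # are consumed by build_from_cfg, which looks up the optimizer class in
    # the OPTIMIZERS registry and instantiates it.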
    return build_from_cfg(optimizer_cfg, OPTIMIZERS)