norm.py 1.33 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
2
import torch.nn as nn

thangvu's avatar
thangvu committed
3

ThangVu's avatar
ThangVu committed
4
5
6
7
8
9
10
# Registry of supported normalization layers, keyed by the config ``type``.
# Each value is a pair: (abbreviation used to build the layer's name,
# layer class). A ``None`` class marks a type that is recognized but has
# no implementation yet; building it raises NotImplementedError.
norm_cfg = dict(
    BN=('bn', nn.BatchNorm2d),
    SyncBN=('bn', None),
    GN=('gn', nn.GroupNorm),
    # 'SN' may be registered here in the future.
)
Kai Chen's avatar
Kai Chen committed
11
12


ThangVu's avatar
ThangVu committed
13
def build_norm_layer(cfg, num_features, postfix=''):
    """Build a normalization layer from a config dict.

    Args:
        cfg (dict): Config of the norm layer. Recognized keys:
            type (str): Key into ``norm_cfg`` identifying the layer type.
            frozen (bool, optional): If True, ``requires_grad`` is set to
                False on every parameter of the built layer, which stops
                gradient updates (useful for backbone norms). Defaults to
                False.
            Every remaining key is forwarded as a keyword argument to the
            layer constructor. ``eps`` defaults to 1e-5 when absent, and a
            'GN' config must also provide ``num_groups``.
        num_features (int): Passed positionally to the layer constructor
            for non-GN types and as ``num_channels`` for 'GN'.
        postfix (int | str): Appended to the layer abbreviation to form the
            returned name, e.g. ``'bn' + '1' -> 'bn1'``.

    Returns:
        tuple: ``(name, layer)``, where ``name`` is the registry
        abbreviation followed by ``str(postfix)`` and ``layer`` is the
        constructed norm module.

    Raises:
        KeyError: If ``cfg['type']`` is not a key of ``norm_cfg``.
        NotImplementedError: If the type is registered with a ``None``
            class (no implementation available yet).
        AssertionError: If ``cfg`` is not a dict containing 'type', if
            ``postfix`` is neither int nor str, or if a 'GN' config lacks
            ``num_groups``.
    """
    assert isinstance(cfg, dict) and 'type' in cfg, \
        'cfg must be a dict containing the key "type"'
    # Shallow copy so that popping keys below does not mutate the caller's cfg.
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in norm_cfg:
        raise KeyError('Unrecognized norm type {}'.format(layer_type))
    abbr, norm_layer = norm_cfg[layer_type]
    if norm_layer is None:
        # The type is registered as a placeholder without an implementation.
        raise NotImplementedError(
            'Norm type {} is registered but not implemented'.format(
                layer_type))

    assert isinstance(postfix, (int, str)), \
        'postfix must be int or str, got {}'.format(type(postfix).__name__)
    name = abbr + str(postfix)

    # 'frozen' is consumed here and must not reach the layer constructor.
    frozen = cfg_.pop('frozen', False)
    cfg_.setdefault('eps', 1e-5)

    if layer_type == 'GN':
        assert 'num_groups' in cfg_, 'GN config requires "num_groups"'
        layer = norm_layer(num_channels=num_features, **cfg_)
    else:
        layer = norm_layer(num_features, **cfg_)

    if frozen:
        # NOTE(review): this only disables gradients of the layer's
        # parameters; it does not call eval(), so presumably BN running
        # statistics still update in train mode — confirm that callers
        # handle that if fully frozen norms are required.
        for param in layer.parameters():
            param.requires_grad = False

    return name, layer