norm.py
import torch.nn as nn

# Mapping from norm type name to layer class ('SyncBN' is a placeholder
# that is not implemented yet).
norm_cfg = {'BN': nn.BatchNorm2d, 'SyncBN': None, 'GN': nn.GroupNorm}
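
# Example cfg dicts accepted by build_norm_layer below (illustrative):
#   dict(type='BN')                 -> nn.BatchNorm2d(num_features, eps=1e-5)
#   dict(type='GN', num_groups=32)  -> nn.GroupNorm(32, num_features, eps=1e-5)
#   dict(type='BN', frozen=True)    -> BatchNorm2d with requires_grad=False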


def build_norm_layer(cfg, num_features):
    """
    cfg should contain:
        type (str): identify norm layer type.
        layer args: args needed to instantiate a norm layer.
        frozen (bool): [optional] whether stop gradient updates
            of norm layer, it is helpful to set frozen mode
            in backbone's norms.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    cfg_ = cfg.copy()  # copy so the caller's cfg dict is not mutated

    layer_type = cfg_.pop('type')
    if layer_type not in norm_cfg:
        raise KeyError('Unrecognized norm type {}'.format(layer_type))
    elif norm_cfg[layer_type] is None:
        raise NotImplementedError(
            '{} is registered but not implemented yet'.format(layer_type))

    frozen = cfg_.pop('frozen', False)
    # map num_features onto the argument name each layer type expects
    if layer_type in ['GN']:
        assert 'num_groups' in cfg  # nn.GroupNorm requires num_groups
        cfg_.setdefault('num_channels', num_features)
    elif layer_type in ['BN']:
        cfg_.setdefault('num_features', num_features)
    else:
        # reached only if a type is added to norm_cfg without arg matching here
        raise NotImplementedError
    cfg_.setdefault('eps', 1e-5)

    norm = norm_cfg[layer_type](**cfg_)
    if frozen:
        # stop gradient updates so the norm's affine params stay fixed
        for param in norm.parameters():
            param.requires_grad = False
    return norm
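

# A minimal usage sketch (illustrative addition, not part of the original
# file). The channel/group counts are arbitrary assumptions; the cfg format
# follows the docstring above.
if __name__ == '__main__':
    # Plain BatchNorm2d over 64 channels; eps defaults to 1e-5.
    bn = build_norm_layer(dict(type='BN'), num_features=64)

    # GroupNorm must be given num_groups; frozen=True disables grad updates.
    gn = build_norm_layer(
        dict(type='GN', num_groups=32, frozen=True), num_features=64)
    assert all(not p.requires_grad for p in gn.parameters())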