norm.py 1.17 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
2
import torch.nn as nn

thangvu's avatar
thangvu committed
3

ThangVu's avatar
ThangVu committed
4
# Registry mapping a norm type name to its layer class.
# 'SyncBN' is registered as a placeholder only: build_norm_layer raises
# NotImplementedError for it until a real implementation is wired in.
norm_cfg = {'BN': nn.BatchNorm2d, 'SyncBN': None, 'GN': nn.GroupNorm}


def build_norm_layer(cfg, num_features):
    """Build a normalization layer from a config dict.

    Args:
        cfg (dict): config for the layer. Recognized keys:
            type (str): norm layer type — one of 'BN', 'GN', 'SyncBN'.
            frozen (bool, optional): if True, stop gradient updates of the
                layer's parameters (useful for frozen backbone norms).
                Defaults to False.
            Any remaining keys are forwarded as keyword arguments to the
            layer constructor (e.g. 'num_groups' for GN, 'eps',
            'momentum' for BN).
        num_features (int): number of input channels/features; used to fill
            in ``num_features`` (BN) or ``num_channels`` (GN) when the
            config does not supply them.

    Returns:
        nn.Module: the instantiated norm layer.

    Raises:
        KeyError: if ``cfg['type']`` is not a registered norm type.
        NotImplementedError: if the type is registered but has no
            implementation yet (currently 'SyncBN').
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    # Work on a copy so the caller's cfg dict is never mutated.
    cfg_ = cfg.copy()
    layer_type = cfg_.pop('type')
    # pop with a default replaces the original conditional-expression idiom
    frozen = cfg_.pop('frozen', False)

    # Validate the type up front so an unknown type fails fast, before any
    # argument-name mangling happens. Exception types are unchanged.
    if layer_type not in norm_cfg:
        raise KeyError('Unrecognized norm type {}'.format(layer_type))
    norm_layer = norm_cfg[layer_type]
    if norm_layer is None:
        raise NotImplementedError

    # Argument-name matching: supply the channel count under the keyword
    # each layer class expects, unless the config already set it.
    if layer_type == 'GN':
        assert 'num_groups' in cfg_
        cfg_.setdefault('num_channels', num_features)
    elif layer_type == 'BN':
        cfg_.setdefault('num_features', num_features)
    cfg_.setdefault('eps', 1e-5)

    norm = norm_layer(**cfg_)
    if frozen:
        # Detach the layer from optimization without changing its
        # forward behavior (running stats still update in train mode).
        for param in norm.parameters():
            param.requires_grad = False
    return norm