# norm.py: normalization layer registry (norm_cfg) and builder (build_norm_layer).
import torch.nn as nn

# Registry of supported normalization layers, keyed by the ``type`` string
# of a norm config dict.
# format: layer_type: (abbreviation, module)
#   abbreviation -- prefix used to build the layer's name in build_norm_layer
#   module       -- nn.Module class to instantiate; None marks a type that is
#                   recognized but has no implementation yet (build_norm_layer
#                   raises NotImplementedError for it).
norm_cfg = {
    'BN': ('bn', nn.BatchNorm2d),
    'SyncBN': ('bn', None),  # placeholder: recognized, not implemented
    'GN': ('gn', nn.GroupNorm),
    # and potentially 'SN'
}


def build_norm_layer(cfg, num_features, postfix=''):
    """Build a normalization layer from a config dict.

    Args:
        cfg (dict): config of the norm layer. It must contain:
            type (str): key into ``norm_cfg`` identifying the layer type.
            frozen (bool, optional): if True, set ``requires_grad = False``
                on every parameter of the created layer (useful for the
                norms of a backbone). Defaults to False.
            Every other key is forwarded as a keyword argument to the layer
            constructor; ``eps`` defaults to 1e-5 when absent. For 'GN',
            ``num_groups`` is required.
        num_features (int): number of channels of the input.
        postfix (int | str): appended to the norm abbreviation to create the
            layer name.

    Returns:
        tuple: ``(name, layer)`` where ``name`` (str) is the abbreviation
        followed by ``postfix`` and ``layer`` (nn.Module) is the created
        norm layer.

    Raises:
        KeyError: if ``cfg['type']`` is not registered in ``norm_cfg``.
        NotImplementedError: if the type is registered without a module
            (e.g. 'SyncBN').
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    # Shallow copy so that popping keys below never mutates the caller's cfg.
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in norm_cfg:
        raise KeyError('Unrecognized norm type {}'.format(layer_type))
    abbr, norm_layer = norm_cfg[layer_type]
    if norm_layer is None:
        raise NotImplementedError(
            'Norm type {} is registered but not implemented'.format(
                layer_type))

    assert isinstance(postfix, (int, str))
    name = abbr + str(postfix)

    frozen = cfg_.pop('frozen', False)
    cfg_.setdefault('eps', 1e-5)

    if layer_type != 'GN':
        layer = norm_layer(num_features, **cfg_)
    else:
        # GroupNorm needs a group count from the config, and its channel
        # argument is passed by keyword as ``num_channels``.
        assert 'num_groups' in cfg_
        layer = norm_layer(num_channels=num_features, **cfg_)

    if frozen:
        # NOTE(review): only parameters are frozen here; buffers such as
        # BatchNorm running statistics are not parameters and presumably
        # still update in train mode -- confirm whether that is intended.
        for param in layer.parameters():
            param.requires_grad = False

    return name, layer