norm.py 1.96 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
2
import torch.nn as nn

thangvu's avatar
thangvu committed
3

ThangVu's avatar
ThangVu committed
4
5
6
7
8
9
10
# Registry mapping a norm layer type name to (abbreviation, module class).
# The abbreviation is used by build_norm_layer to construct the layer name;
# a module of None marks a type that is recognized but not implemented here
# (build_norm_layer raises NotImplementedError for it).
norm_cfg = {
    # format: layer_type: (abbreviation, module)
    'BN': ('bn', nn.BatchNorm2d),
    'SyncBN': ('bn', None),  # recognized but no module registered yet
    'GN': ('gn', nn.GroupNorm),
    # and potentially 'SN' (switchable normalization)
}
Kai Chen's avatar
Kai Chen committed
11
12


ThangVu's avatar
ThangVu committed
13
def build_norm_layer(cfg, num_features, postfix=''):
    """Build a normalization layer.

    Args:
        cfg (dict): layer config, must contain:
            type (str): identifies the norm layer type (a key of ``norm_cfg``).
            frozen (bool): [optional] whether to stop gradient updates of the
                norm layer; helpful for freezing norms in a backbone.
            layer args: any remaining keys are passed to the layer
                constructor (``eps`` defaults to 1e-5; ``GN`` additionally
                requires ``num_groups``).
        num_features (int): number of channels from input.
        postfix (int, str): appended to the norm abbreviation to create a
            named layer.

    Returns:
        tuple: ``(name, layer)`` where ``name`` (str) is the abbreviation
        plus postfix and ``layer`` (nn.Module) is the created norm layer.

    Raises:
        ValueError: if ``eval_mode`` is present in ``cfg`` — it is only
            supported by modules with pretrained weights (e.g. a backbone).
        KeyError: if ``cfg['type']`` is not a key of ``norm_cfg``.
        NotImplementedError: if the requested type is registered without a
            module implementation (e.g. ``SyncBN``).
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    # Work on a copy so the pops below never mutate the caller's cfg dict.
    cfg_ = cfg.copy()

    # eval_mode is supported (and popped) only by modules that load
    # pretrained weights (e.g. backbone); reject it explicitly here.
    if 'eval_mode' in cfg_:
        raise ValueError('eval_mode for modules without pretrained weights '
                         'is not supported')

    layer_type = cfg_.pop('type')
    if layer_type not in norm_cfg:
        raise KeyError('Unrecognized norm type {}'.format(layer_type))
    abbr, norm_layer = norm_cfg[layer_type]
    if norm_layer is None:
        raise NotImplementedError(
            'norm type {} is registered but has no module '
            'implementation'.format(layer_type))

    assert isinstance(postfix, (int, str))
    name = abbr + str(postfix)

    frozen = cfg_.pop('frozen', False)
    cfg_.setdefault('eps', 1e-5)
    if layer_type != 'GN':
        layer = norm_layer(num_features, **cfg_)
    else:
        # GroupNorm takes the channel count as a keyword argument and
        # requires num_groups to be given explicitly.
        assert 'num_groups' in cfg_
        layer = norm_layer(num_channels=num_features, **cfg_)

    # Frozen mode: detach all affine parameters from gradient updates.
    if frozen:
        for param in layer.parameters():
            param.requires_grad = False

    return name, layer