norm.py 1.65 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
2
import torch.nn as nn

thangvu's avatar
thangvu committed
3

ThangVu's avatar
ThangVu committed
4
# Registry of supported normalization layers.
# Maps a config ``type`` string to ``(abbreviation, module class)``; the
# abbreviation is used as the prefix of the generated layer name.
norm_cfg = dict(
    BN=('bn', nn.BatchNorm2d),
    SyncBN=('bn', nn.SyncBatchNorm),
    GN=('gn', nn.GroupNorm),
    # and potentially 'SN'
)
Kai Chen's avatar
Kai Chen committed
11
12


ThangVu's avatar
ThangVu committed
13
def build_norm_layer(cfg, num_features, postfix=''):
    """Build a normalization layer from a config dict.

    Args:
        cfg (dict): Layer config. It must contain:
            type (str): Norm layer type; a key of ``norm_cfg``
                (e.g. 'BN', 'SyncBN', 'GN').
            layer args: Extra keyword arguments forwarded to the norm
                layer constructor. 'eps' defaults to 1e-5 when absent;
                'GN' additionally requires 'num_groups'.
            requires_grad (bool, optional): Whether the layer's parameters
                receive gradient updates. Defaults to True; set it to False
                to freeze the layer's parameters.
        num_features (int): Number of channels of the input.
        postfix (int | str): Appended to the norm abbreviation to create
            the layer name.

    Returns:
        tuple[str, nn.Module]: ``(name, layer)``, where ``name`` is the
        abbreviation + postfix and ``layer`` is the created norm layer.

    Raises:
        KeyError: If ``cfg['type']`` is not registered in ``norm_cfg``.
        NotImplementedError: If the registered module for the type is None.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    # Work on a copy so popping 'type' / 'requires_grad' below does not
    # mutate the caller's config.
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in norm_cfg:
        raise KeyError('Unrecognized norm type {}'.format(layer_type))
    abbr, norm_layer = norm_cfg[layer_type]
    if norm_layer is None:
        # Allows a type to be registered with no module as a placeholder.
        raise NotImplementedError

    assert isinstance(postfix, (int, str))
    name = abbr + str(postfix)

    requires_grad = cfg_.pop('requires_grad', True)
    cfg_.setdefault('eps', 1e-5)
    if layer_type != 'GN':
        layer = norm_layer(num_features, **cfg_)
        # ``_specify_ddp_gpu_num`` is a private SyncBatchNorm hook that was
        # deprecated and later removed from PyTorch; calling it
        # unconditionally raises AttributeError on newer versions.
        if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
            layer._specify_ddp_gpu_num(1)
    else:
        assert 'num_groups' in cfg_
        # GroupNorm names its channel argument ``num_channels`` rather than
        # taking it as the first positional argument.
        layer = norm_layer(num_channels=num_features, **cfg_)

    for param in layer.parameters():
        param.requires_grad = requires_grad

    return name, layer