norm.py 1.64 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
2
import torch.nn as nn

ThangVu's avatar
ThangVu committed
3
# Registry of supported normalization layers.
# Each entry maps a config ``type`` string to a pair
# ``(abbreviation, module class)``; ``build_norm_layer`` prefixes the
# created layer's name with the abbreviation.
# Further types (e.g. 'SN') may be registered here later.
norm_cfg = dict(
    BN=('bn', nn.BatchNorm2d),
    SyncBN=('bn', nn.SyncBatchNorm),
    GN=('gn', nn.GroupNorm),
)
Kai Chen's avatar
Kai Chen committed
10
11


ThangVu's avatar
ThangVu committed
12
def build_norm_layer(cfg, num_features, postfix=''):
    """Build a normalization layer from a config dict.

    Args:
        cfg (dict): Config of the norm layer. It must contain:
            type (str): Key into ``norm_cfg`` selecting the layer type.
            requires_grad (bool, optional): Whether the layer's parameters
                receive gradient updates. Defaults to True; pass False to
                freeze them.
            Every other key is forwarded as a keyword argument to the layer
            constructor. ``eps`` defaults to 1e-5 when absent, and the 'GN'
            type additionally requires ``num_groups``.
        num_features (int): Number of channels of the input.
        postfix (int | str): Appended to the norm abbreviation to form the
            layer name.

    Returns:
        tuple[str, nn.Module]: ``(name, layer)`` where ``name`` is the
        abbreviation followed by ``str(postfix)`` and ``layer`` is the
        created norm module.

    Raises:
        KeyError: If ``cfg['type']`` is not registered in ``norm_cfg``.
        NotImplementedError: If the registered module for the type is None.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    # Work on a copy so popping keys never mutates the caller's config.
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in norm_cfg:
        raise KeyError('Unrecognized norm type {}'.format(layer_type))
    abbr, norm_layer = norm_cfg[layer_type]
    if norm_layer is None:
        raise NotImplementedError

    assert isinstance(postfix, (int, str))
    name = abbr + str(postfix)

    # ``requires_grad`` is not a constructor argument, so strip it before
    # forwarding the remaining keys to the layer class.
    requires_grad = cfg_.pop('requires_grad', True)
    cfg_.setdefault('eps', 1e-5)
    if layer_type != 'GN':
        layer = norm_layer(num_features, **cfg_)
        # ``_specify_ddp_gpu_num`` is a private SyncBatchNorm API that is not
        # present in newer PyTorch releases; calling it unconditionally would
        # raise AttributeError there, so only call it when it exists.
        if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
            layer._specify_ddp_gpu_num(1)
    else:
        assert 'num_groups' in cfg_
        # GroupNorm takes the channel count as ``num_channels``.
        layer = norm_layer(num_channels=num_features, **cfg_)

    for param in layer.parameters():
        param.requires_grad = requires_grad

    return name, layer