norm.py

import torch.nn as nn

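# Mapping from config type names to norm layer classes ('SyncBN' is a
# placeholder and is not implemented yet).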
norm_cfg = {'BN': nn.BatchNorm2d, 'SyncBN': None, 'GN': nn.GroupNorm}


def build_norm_layer(cfg, num_features):
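    """Build a normalization layer from a config dict.

    Args:
        cfg (dict): config with a 'type' key ('BN', 'SyncBN' or 'GN');
            for 'GN' a 'num_groups' entry is required.
        num_features (int): number of input channels of the layer.

    Returns:
        nn.Module: the constructed normalization layer.
    """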
    assert isinstance(cfg, dict) and 'type' in cfg
    cfg_ = cfg.copy()
    layer_type = cfg_.pop('type')

    # args name matching
    if layer_type == 'GN':
        assert 'num_groups' in cfg
        cfg_.setdefault('num_channels', num_features)
    elif layer_type == 'BN':
        cfg_ = dict()  # discard extra args; rewrite necessary info for BN from here
        cfg_.setdefault('num_features', num_features)
    cfg_.setdefault('eps', 1e-5)

    if layer_type not in norm_cfg:
        raise KeyError('Unrecognized norm type {}'.format(layer_type))
    elif norm_cfg[layer_type] is None:
        raise NotImplementedError

    return norm_cfg[layer_type](**cfg_)
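

# A minimal usage sketch (illustrative channel/group values; not part of
# the original module):
if __name__ == '__main__':
    # BatchNorm2d over 64 channels; extra cfg keys are discarded for 'BN'.
    bn = build_norm_layer(dict(type='BN'), num_features=64)
    # GroupNorm with 32 groups over 64 channels; 'num_groups' is required.
    gn = build_norm_layer(dict(type='GN', num_groups=32), num_features=64)
    # build_norm_layer(dict(type='SyncBN'), 64) would raise NotImplementedError.
    print(bn, gn)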