# norm.py
# Blame attribution (recovered from a scraped history view): Kai Chen, lines 1-2.
import torch.nn as nn

# Blame attribution: ThangVu, line 3.
# Registry mapping a norm type name to its layer class. A ``None`` entry marks
# a type that is recognized but has no implementation; build_norm_layer raises
# NotImplementedError for such entries.
norm_cfg = dict(BN=nn.BatchNorm2d, SyncBN=None, GN=nn.GroupNorm)
# Blame attribution: Kai Chen, lines 4-11.


def build_norm_layer(cfg, num_features):
    """Build a normalization layer from a config dict.

    Args:
        cfg (dict): Layer config. Must contain ``'type'``, a key of
            ``norm_cfg`` (e.g. ``'BN'`` or ``'GN'``). Every other entry is
            forwarded as a keyword argument to the layer constructor.
            ``eps`` defaults to ``1e-5`` when not given. The dict itself is
            not modified.
        num_features (int): Channel count of the input. It is passed as
            ``num_channels`` for ``'GN'`` and as ``num_features`` for every
            other type, unless ``cfg`` already supplies that argument.

    Returns:
        The layer instance built by ``norm_cfg[cfg['type']](**kwargs)``.

    Raises:
        TypeError: If ``cfg`` is not a dict.
        KeyError: If ``cfg`` has no ``'type'`` key, or the type is not a key
            of ``norm_cfg``.
        NotImplementedError: If the type is registered in ``norm_cfg`` with
            a ``None`` entry (no implementation available).
    """
    # Explicit checks rather than ``assert``: asserts are stripped under
    # ``python -O``, which would silently drop this validation.
    if not isinstance(cfg, dict):
        raise TypeError(
            'cfg must be a dict, got {}'.format(type(cfg).__name__))
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')

    # Copy so that popping 'type' and filling defaults never touches the
    # caller's config.
    cfg_ = cfg.copy()
    layer_type = cfg_.pop('type')

    if layer_type not in norm_cfg:
        raise KeyError('Unrecognized norm type {}'.format(layer_type))
    norm_layer = norm_cfg[layer_type]
    if norm_layer is None:
        raise NotImplementedError(
            'Norm type {} is not implemented yet'.format(layer_type))

    cfg_.setdefault('eps', 1e-5)
    # Argument-name matching: the 'GN' entry takes its channel count as
    # ``num_channels``; all other registered types take ``num_features``.
    if layer_type == 'GN':
        cfg_.setdefault('num_channels', num_features)
    else:
        cfg_.setdefault('num_features', num_features)

    return norm_layer(**cfg_)