Commit f64c9561 authored by thangvu's avatar thangvu

revise group norm

parent 5b9a200e
@@ -11,7 +11,9 @@ model = dict(
style='pytorch',
normalize=dict(
type='GN',
num_groups=32)),
num_groups=32,
eval=False,
frozen=False)),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
......
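For comparison, the same backbone could keep BN under the revised schema; a hedged sketch, with values mirroring the BN defaults in the ResNet signature further down:

normalize=dict(
    type='BN',
    eval=True,      # freeze BN running stats during training
    frozen=False)   # keep the affine weight/bias trainable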
import logging
import pickle
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
@@ -33,7 +31,8 @@ class BasicBlock(nn.Module):
downsample=None,
style='pytorch',
with_cp=False,
normalize=dict(type='BN')):
normalize=dict(type='BN'),
frozen=False):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride, dilation)
@@ -52,6 +51,10 @@ class BasicBlock(nn.Module):
self.dilation = dilation
assert not with_cp
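# freeze every parameter of this block at construction time;
# the parent ResNet sets frozen per stage via frozen_stages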
if frozen:
for param in self.parameters():
param.requires_grad = False
def forward(self, x):
identity = x
@@ -82,7 +85,8 @@ class Bottleneck(nn.Module):
downsample=None,
style='pytorch',
with_cp=False,
normalize=dict(type='BN')):
normalize=dict(type='BN'),
frozen=False):
"""Bottleneck block for ResNet.
If style is "pytorch", the stride-two layer is the 3x3 conv layer,
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
@@ -130,6 +134,10 @@ class Bottleneck(nn.Module):
self.with_cp = with_cp
self.normalize = normalize
if frozen:
for param in self.parameters():
param.requires_grad = False
def forward(self, x):
def _inner_forward(x):
@@ -171,7 +179,8 @@ def make_res_layer(block,
dilation=1,
style='pytorch',
with_cp=False,
normalize=dict(type='BN')):
normalize=dict(type='BN'),
frozen=False):
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
@@ -194,7 +203,8 @@ def make_res_layer(block,
downsample,
style=style,
with_cp=with_cp,
normalize=normalize))
normalize=normalize,
frozen=frozen))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
@@ -218,9 +228,9 @@ class ResNet(nn.Module):
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
bn_eval (bool): Whether to set BN layers to eval mode, namely, freeze
running stats (mean and var).
bn_frozen (bool): Whether to freeze weight and bias of BN layers.
normalize (dict): dictionary to construct the norm layer. Additionally,
eval mode and gradient freezing are controlled by
eval (bool) and frozen (bool), respectively.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
"""
@@ -243,8 +253,8 @@ class ResNet(nn.Module):
frozen_stages=-1,
normalize=dict(
type='BN',
bn_eval=True,
bn_frozen=False),
eval=True,
frozen=False),
with_cp=False):
super(ResNet, self).__init__()
if depth not in self.arch_settings:
@@ -258,32 +268,17 @@ class ResNet(nn.Module):
self.out_indices = out_indices
assert max(out_indices) < num_stages
self.style = style
self.frozen_stages = frozen_stages
self.with_cp = with_cp
assert isinstance(normalize, dict) and 'type' in normalize
assert normalize['type'] in ['BN', 'GN']
if normalize['type'] == 'GN':
assert 'num_groups' in normalize
else:
assert (set(['type', 'bn_eval', 'bn_frozen'])
== set(normalize))
if normalize['type'] == 'BN':
self.bn_eval = normalize['bn_eval']
self.bn_frozen = normalize['bn_frozen']
self.is_frozen = [i <= frozen_stages for i in range(num_stages + 1)]
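# e.g. num_stages=4, frozen_stages=1 -> [True, True, False, False, False]
# (index 0 is the stem; indices 1..num_stages are the res layers)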
assert (isinstance(normalize, dict) and 'eval' in normalize
and 'frozen' in normalize)
self.norm_eval = normalize.pop('eval')
self.normalize = normalize
self.block, stage_blocks = self.arch_settings[depth]
self.stage_blocks = stage_blocks[:num_stages]
self.inplanes = 64
self.conv1 = nn.Conv2d(
3, 64, kernel_size=7, stride=2, padding=3, bias=False)
stem_norm = build_norm_layer(normalize, 64)
self.stem_norm_name = 'gn1' if normalize['type'] == 'GN' else 'bn1'
self.add_module(self.stem_norm_name, stem_norm)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self._make_stem_layer()
self.res_layers = []
for i, num_blocks in enumerate(self.stage_blocks):
@@ -299,7 +294,8 @@ class ResNet(nn.Module):
dilation=dilation,
style=self.style,
with_cp=with_cp,
normalize=normalize)
normalize=normalize,
frozen=self.is_frozen[i + 1])
self.inplanes = planes * self.block.expansion
layer_name = 'layer{}'.format(i + 1)
self.add_module(layer_name, res_layer)
@@ -308,6 +304,20 @@ class ResNet(nn.Module):
self.feat_dim = self.block.expansion * 64 * 2**(
len(self.stage_blocks) - 1)
def _make_stem_layer(self):
self.conv1 = nn.Conv2d(
3, 64, kernel_size=7, stride=2, padding=3, bias=False)
stem_norm = build_norm_layer(self.normalize, 64)
self.norm_name = 'gn1' if self.normalize['type'] == 'GN' else 'bn1'
self.add_module(self.norm_name, stem_norm)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
if self.is_frozen[0]:
for layer in [self.conv1, stem_norm]:
for param in layer.parameters():
param.requires_grad = False
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = logging.getLogger()
@@ -316,7 +326,8 @@ class ResNet(nn.Module):
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, nn.BatchNorm2d):
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
constant_init(m, 1)
# zero init for last norm layer https://arxiv.org/abs/1706.02677
@@ -329,7 +340,7 @@ class ResNet(nn.Module):
def forward(self, x):
x = self.conv1(x)
x = getattr(self, self.stem_norm_name)(x)
x = getattr(self, self.norm_name)(x)
x = self.relu(x)
x = self.maxpool(x)
outs = []
@@ -345,121 +356,8 @@ class ResNet(nn.Module):
def train(self, mode=True):
super(ResNet, self).train(mode)
if self.normalize['type'] == 'BN' and self.bn_eval:
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
if self.bn_frozen:
for params in m.parameters():
params.requires_grad = False
if mode and self.frozen_stages >= 0:
for param in self.conv1.parameters():
param.requires_grad = False
stem_norm = getattr(self, self.stem_norm_name)
stem_norm.eval()
for param in stem_norm.parameters():
param.requires_grad = False
for i in range(1, self.frozen_stages + 1):
mod = getattr(self, 'layer{}'.format(i))
mod.eval()
for param in mod.parameters():
param.requires_grad = False
class ResNetClassifier(ResNet):
def __init__(self,
depth,
num_stages=4,
strides=(1, 2, 2, 2),
dilations=(1, 1, 1, 1),
out_indices=(0, 1, 2, 3),
style='pytorch',
normalize=dict(
type='BN',
frozen_stages=-1,
bn_eval=True,
bn_frozen=False),
with_cp=False,
num_classes=1000):
super(ResNetClassifier, self).__init__(depth,
num_stages=num_stages,
strides=strides,
dilations=dilations,
out_indices=out_indices,
style=style,
normalize=normalize,
with_cp=with_cp)
_, self.stage_blocks = self.arch_settings[depth]
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
expansion = 1 if depth == 18 else 4
self.fc = nn.Linear(512 * expansion, num_classes)
self.init_weights()
# TODO can be removed after tested
def load_caffe2_weight(self, cf_path):
norm = 'gn' if self.normalize['type'] == 'GN' else 'bn'
mapping = {}
for layer, blocks_in_layer in enumerate(self.stage_blocks, 1):
for blk in range(blocks_in_layer):
cf_prefix = 'res%d_%d_' % (layer + 1, blk)
py_prefix = 'layer%d.%d.' % (layer, blk)
# conv branch
for i, a in zip([1, 2, 3], ['a', 'b', 'c']):
cf_full = cf_prefix + 'branch2%s_' % a
mapping[py_prefix + 'conv%d.weight' % i] = cf_full + 'w'
mapping[py_prefix + norm + '%d.weight' % i] \
= cf_full + norm + '_s'
mapping[py_prefix + norm + '%d.bias' % i] \
= cf_full + norm + '_b'
# downsample branch
cf_full = 'res%d_0_branch1_' % (layer + 1)
py_full = 'layer%d.0.downsample.' % layer
mapping[py_full + '0.weight'] = cf_full + 'w'
mapping[py_full + '1.weight'] = cf_full + norm + '_s'
mapping[py_full + '1.bias'] = cf_full + norm + '_b'
# stem layers and last fc layer
if self.normalize['type'] == 'GN':
mapping['conv1.weight'] = 'conv1_w'
mapping['gn1.weight'] = 'conv1_gn_s'
mapping['gn1.bias'] = 'conv1_gn_b'
mapping['fc.weight'] = 'pred_w'
mapping['fc.bias'] = 'pred_b'
else:
mapping['conv1.weight'] = 'conv1_w'
mapping['bn1.weight'] = 'res_conv1_bn_s'
mapping['bn1.bias'] = 'res_conv1_bn_b'
mapping['fc.weight'] = 'fc1000_w'
mapping['fc.bias'] = 'fc1000_b'
# load state dict
py_state = self.state_dict()
with open(cf_path, 'rb') as f:
cf_state = pickle.load(f, encoding='latin1')
if 'blobs' in cf_state:
cf_state = cf_state['blobs']
for i, (py_k, cf_k) in enumerate(mapping.items(), 1):
print('[{}/{}] Loading {} to {}'.format(
i, len(mapping), cf_k, py_k))
assert py_k in py_state and cf_k in cf_state
py_state[py_k] = torch.Tensor(cf_state[cf_k])
self.load_state_dict(py_state)
def forward(self, x):
x = self.conv1(x)
x = getattr(self, self.stem_norm_name)(x)
x = self.relu(x)
x = self.maxpool(x)
for i, layer_name in enumerate(self.res_layers):
res_layer = getattr(self, layer_name)
x = res_layer(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
if mode and self.norm_eval:
for mod in self.modules():
# trick: eval() only has an effect on BatchNorm layers
if isinstance(mod, nn.BatchNorm2d):
mod.eval()
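A hedged usage sketch of the revised train-time behavior (the flag values are illustrative):

model = ResNet(
    depth=50,
    frozen_stages=1,
    normalize=dict(type='BN', eval=True, frozen=False))
model.train()  # norm_eval puts every BatchNorm2d back into eval mode;
               # the stem and layer1 parameters stay fixed via is_frozen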
import torch.nn as nn
norm_cfg = {'BN': nn.BatchNorm2d, 'SyncBN': None, 'GN': nn.GroupNorm}
def build_norm_layer(cfg, num_features):
"""
cfg should contain:
type (str): identify norm layer type.
layer args: args needed to instantiate a norm layer.
frozen (bool): [optional] whether stop gradient updates
of norm layer, it is helpful to set frozen mode
in backbone's norms.
"""
assert isinstance(cfg, dict) and 'type' in cfg
cfg_ = cfg.copy()
layer_type = cfg_.pop('type')
frozen = cfg_.pop('frozen', False)
# args name matching
if layer_type == 'GN':
assert 'num_groups' in cfg
cfg_.setdefault('num_channels', num_features)
elif layer_type == 'BN':
cfg_ = dict()  # rebuild the necessary BN args from scratch
cfg_.setdefault('num_features', num_features)
cfg_.setdefault('eps', 1e-5)
@@ -22,4 +31,8 @@ def build_norm_layer(cfg, num_features):
elif norm_cfg[layer_type] is None:
raise NotImplementedError
return norm_cfg[layer_type](**cfg_)
norm = norm_cfg[layer_type](**cfg_)
if frozen:
for param in norm.parameters():
param.requires_grad = False
return norm
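A quick sketch of build_norm_layer under both supported configs (shapes are illustrative):

import torch

# GN: num_groups comes from cfg, num_channels is filled in from the call site
gn = build_norm_layer(dict(type='GN', num_groups=32, frozen=True), 64)
assert all(not p.requires_grad for p in gn.parameters())

# BN: extra cfg keys are discarded and the necessary args are rebuilt
bn = build_norm_layer(dict(type='BN'), 64)
out = bn(torch.randn(2, 64, 56, 56))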