Commit f64c9561 authored by thangvu

revise group norm

parent 5b9a200e
@@ -11,7 +11,9 @@ model = dict(
         style='pytorch',
         normalize=dict(
             type='GN',
-            num_groups=32)),
+            num_groups=32,
+            eval=False,
+            frozen=False)),
     neck=dict(
         type='FPN',
         in_channels=[256, 512, 1024, 2048],
...
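For orientation (not part of the commit): the two new keys are not GroupNorm arguments. `eval` and `frozen` are bookkeeping flags consumed by the backbone and by `build_norm_layer` (revised further down), while the remaining keys describe the norm layer itself. A minimal sketch of that split, with an assumed channel count of 256:

```python
import torch.nn as nn

# Sketch only: 'eval' and 'frozen' are stripped before the layer is built;
# what remains ('num_groups' plus the channel count supplied by the caller)
# is what actually reaches nn.GroupNorm.
normalize = dict(type='GN', num_groups=32, eval=False, frozen=False)
layer_args = {k: v for k, v in normalize.items()
              if k not in ('type', 'eval', 'frozen')}

gn = nn.GroupNorm(num_channels=256, **layer_args)
assert all(p.requires_grad for p in gn.parameters())  # frozen=False: still trainable
```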
 import logging
-import pickle
-import torch
 import torch.nn as nn
 import torch.utils.checkpoint as cp
@@ -33,7 +31,8 @@ class BasicBlock(nn.Module):
                  downsample=None,
                  style='pytorch',
                  with_cp=False,
-                 normalize=dict(type='BN')):
+                 normalize=dict(type='BN'),
+                 frozen=False):
         super(BasicBlock, self).__init__()
         self.conv1 = conv3x3(inplanes, planes, stride, dilation)
@@ -52,6 +51,10 @@ class BasicBlock(nn.Module):
         self.dilation = dilation
         assert not with_cp

+        if frozen:
+            for param in self.parameters():
+                param.requires_grad = False
+
     def forward(self, x):
         identity = x
@@ -82,7 +85,8 @@ class Bottleneck(nn.Module):
                  downsample=None,
                  style='pytorch',
                  with_cp=False,
-                 normalize=dict(type='BN')):
+                 normalize=dict(type='BN'),
+                 frozen=False):
         """Bottleneck block for ResNet.
         If style is "pytorch", the stride-two layer is the 3x3 conv layer,
         if it is "caffe", the stride-two layer is the first 1x1 conv layer.
@@ -130,6 +134,10 @@ class Bottleneck(nn.Module):
         self.with_cp = with_cp
         self.normalize = normalize

+        if frozen:
+            for param in self.parameters():
+                param.requires_grad = False
+
     def forward(self, x):

         def _inner_forward(x):
@@ -171,7 +179,8 @@ def make_res_layer(block,
                    dilation=1,
                    style='pytorch',
                    with_cp=False,
-                   normalize=dict(type='BN')):
+                   normalize=dict(type='BN'),
+                   frozen=False):
     downsample = None
     if stride != 1 or inplanes != planes * block.expansion:
         downsample = nn.Sequential(
@@ -194,7 +203,8 @@ def make_res_layer(block,
             downsample,
             style=style,
             with_cp=with_cp,
-            normalize=normalize))
+            normalize=normalize,
+            frozen=frozen))
     inplanes = planes * block.expansion
     for i in range(1, blocks):
         layers.append(
@@ -218,9 +228,9 @@ class ResNet(nn.Module):
             the first 1x1 conv layer.
         frozen_stages (int): Stages to be frozen (all param fixed). -1 means
             not freezing any parameters.
-        bn_eval (bool): Whether to set BN layers to eval mode, namely, freeze
-            running stats (mean and var).
-        bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+        normalize (dict): Dictionary to construct the norm layer. Additionally,
+            eval mode and gradient freezing are controlled by the
+            eval (bool) and frozen (bool) keys, respectively.
         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
             memory while slowing down the training speed.
     """
@@ -243,8 +253,8 @@ class ResNet(nn.Module):
                  frozen_stages=-1,
                  normalize=dict(
                      type='BN',
-                     bn_eval=True,
-                     bn_frozen=False),
+                     eval=True,
+                     frozen=False),
                  with_cp=False):
         super(ResNet, self).__init__()
         if depth not in self.arch_settings:
@@ -258,32 +268,17 @@ class ResNet(nn.Module):
         self.out_indices = out_indices
         assert max(out_indices) < num_stages
         self.style = style
-        self.frozen_stages = frozen_stages
         self.with_cp = with_cp
+        self.is_frozen = [i <= frozen_stages for i in range(num_stages + 1)]

-        assert isinstance(normalize, dict) and 'type' in normalize
-        assert normalize['type'] in ['BN', 'GN']
-        if normalize['type'] == 'GN':
-            assert 'num_groups' in normalize
-        else:
-            assert (set(['type', 'bn_eval', 'bn_frozen'])
-                    == set(normalize))
-        if normalize['type'] == 'BN':
-            self.bn_eval = normalize['bn_eval']
-            self.bn_frozen = normalize['bn_frozen']
+        assert (isinstance(normalize, dict) and 'eval' in normalize
+                and 'frozen' in normalize)
+        self.norm_eval = normalize.pop('eval')
         self.normalize = normalize

         self.block, stage_blocks = self.arch_settings[depth]
         self.stage_blocks = stage_blocks[:num_stages]
         self.inplanes = 64
-        self.conv1 = nn.Conv2d(
-            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
-        stem_norm = build_norm_layer(normalize, 64)
-        self.stem_norm_name = 'gn1' if normalize['type'] == 'GN' else 'bn1'
-        self.add_module(self.stem_norm_name, stem_norm)
-        self.relu = nn.ReLU(inplace=True)
-        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self._make_stem_layer()

         self.res_layers = []
         for i, num_blocks in enumerate(self.stage_blocks):
@@ -299,7 +294,8 @@ class ResNet(nn.Module):
                 dilation=dilation,
                 style=self.style,
                 with_cp=with_cp,
-                normalize=normalize)
+                normalize=normalize,
+                frozen=self.is_frozen[i + 1])
             self.inplanes = planes * self.block.expansion
             layer_name = 'layer{}'.format(i + 1)
             self.add_module(layer_name, res_layer)
@@ -308,6 +304,20 @@ class ResNet(nn.Module):
         self.feat_dim = self.block.expansion * 64 * 2**(
             len(self.stage_blocks) - 1)

+    def _make_stem_layer(self):
+        self.conv1 = nn.Conv2d(
+            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+        stem_norm = build_norm_layer(self.normalize, 64)
+        self.norm_name = 'gn1' if self.normalize['type'] == 'GN' else 'bn1'
+        self.add_module(self.norm_name, stem_norm)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        if self.is_frozen[0]:
+            for layer in [self.conv1, stem_norm]:
+                for param in layer.parameters():
+                    param.requires_grad = False
+
     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
             logger = logging.getLogger()
@@ -316,7 +326,8 @@ class ResNet(nn.Module):
             for m in self.modules():
                 if isinstance(m, nn.Conv2d):
                     kaiming_init(m)
-                elif isinstance(m, nn.BatchNorm2d):
+                elif (isinstance(m, nn.BatchNorm2d)
+                      or isinstance(m, nn.GroupNorm)):
                     constant_init(m, 1)

             # zero init for last norm layer https://arxiv.org/abs/1706.02677
@@ -329,7 +340,7 @@ class ResNet(nn.Module):
     def forward(self, x):
         x = self.conv1(x)
-        x = getattr(self, self.stem_norm_name)(x)
+        x = getattr(self, self.norm_name)(x)
         x = self.relu(x)
         x = self.maxpool(x)
         outs = []
@@ -345,121 +356,8 @@ class ResNet(nn.Module):
     def train(self, mode=True):
         super(ResNet, self).train(mode)
-        if self.normalize['type'] == 'BN' and self.bn_eval:
-            for m in self.modules():
-                if isinstance(m, nn.BatchNorm2d):
-                    m.eval()
-                    if self.bn_frozen:
-                        for params in m.parameters():
-                            params.requires_grad = False
-        if mode and self.frozen_stages >= 0:
-            for param in self.conv1.parameters():
-                param.requires_grad = False
-            stem_norm = getattr(self, self.stem_norm_name)
-            stem_norm.eval()
-            for param in stem_norm.parameters():
-                param.requires_grad = False
-            for i in range(1, self.frozen_stages + 1):
-                mod = getattr(self, 'layer{}'.format(i))
-                mod.eval()
-                for param in mod.parameters():
-                    param.requires_grad = False
+        if mode and self.norm_eval:
+            for mod in self.modules():
+                # trick: eval() only has an effect on BatchNorm layers
+                if isinstance(mod, nn.BatchNorm2d):
+                    mod.eval()
-
-
-class ResNetClassifier(ResNet):
-
-    def __init__(self,
-                 depth,
-                 num_stages=4,
-                 strides=(1, 2, 2, 2),
-                 dilations=(1, 1, 1, 1),
-                 out_indices=(0, 1, 2, 3),
-                 style='pytorch',
-                 normalize=dict(
-                     type='BN',
-                     frozen_stages=-1,
-                     bn_eval=True,
-                     bn_frozen=False),
-                 with_cp=False,
-                 num_classes=1000):
-        super(ResNetClassifier, self).__init__(depth,
-                                               num_stages=num_stages,
-                                               strides=strides,
-                                               dilations=dilations,
-                                               out_indices=out_indices,
-                                               style=style,
-                                               normalize=normalize,
-                                               with_cp=with_cp)
-        _, self.stage_blocks = self.arch_settings[depth]
-        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-        expansion = 1 if depth == 18 else 4
-        self.fc = nn.Linear(512 * expansion, num_classes)
-        self.init_weights()
-
-    # TODO can be removed after tested
-    def load_caffe2_weight(self, cf_path):
-        norm = 'gn' if self.normalize['type'] == 'GN' else 'bn'
-        mapping = {}
-        for layer, blocks_in_layer in enumerate(self.stage_blocks, 1):
-            for blk in range(blocks_in_layer):
-                cf_prefix = 'res%d_%d_' % (layer + 1, blk)
-                py_prefix = 'layer%d.%d.' % (layer, blk)
-                # conv branch
-                for i, a in zip([1, 2, 3], ['a', 'b', 'c']):
-                    cf_full = cf_prefix + 'branch2%s_' % a
-                    mapping[py_prefix + 'conv%d.weight' % i] = cf_full + 'w'
-                    mapping[py_prefix + norm + '%d.weight' % i] \
-                        = cf_full + norm + '_s'
-                    mapping[py_prefix + norm + '%d.bias' % i] \
-                        = cf_full + norm + '_b'
-            # downsample branch
-            cf_full = 'res%d_0_branch1_' % (layer + 1)
-            py_full = 'layer%d.0.downsample.' % layer
-            mapping[py_full + '0.weight'] = cf_full + 'w'
-            mapping[py_full + '1.weight'] = cf_full + norm + '_s'
-            mapping[py_full + '1.bias'] = cf_full + norm + '_b'
-        # stem layers and last fc layer
-        if self.normalize['type'] == 'GN':
-            mapping['conv1.weight'] = 'conv1_w'
-            mapping['gn1.weight'] = 'conv1_gn_s'
-            mapping['gn1.bias'] = 'conv1_gn_b'
-            mapping['fc.weight'] = 'pred_w'
-            mapping['fc.bias'] = 'pred_b'
-        else:
-            mapping['conv1.weight'] = 'conv1_w'
-            mapping['bn1.weight'] = 'res_conv1_bn_s'
-            mapping['bn1.bias'] = 'res_conv1_bn_b'
-            mapping['fc.weight'] = 'fc1000_w'
-            mapping['fc.bias'] = 'fc1000_b'
-        # load state dict
-        py_state = self.state_dict()
-        with open(cf_path, 'rb') as f:
-            cf_state = pickle.load(f, encoding='latin1')
-            if 'blobs' in cf_state:
-                cf_state = cf_state['blobs']
-        for i, (py_k, cf_k) in enumerate(mapping.items(), 1):
-            print('[{}/{}] Loading {} to {}'.format(
-                i, len(mapping), cf_k, py_k))
-            assert py_k in py_state and cf_k in cf_state
-            py_state[py_k] = torch.Tensor(cf_state[cf_k])
-        self.load_state_dict(py_state)
-
-    def forward(self, x):
-        x = self.conv1(x)
-        x = getattr(self, self.stem_norm_name)(x)
-        x = self.relu(x)
-        x = self.maxpool(x)
-        for i, layer_name in enumerate(self.res_layers):
-            res_layer = getattr(self, layer_name)
-            x = res_layer(x)
-        x = self.avgpool(x)
-        x = x.view(x.size(0), -1)
-        x = self.fc(x)
-        return x
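As background for the `norm_eval` behaviour introduced in `ResNet.train()` above (a minimal sketch in plain PyTorch, not part of the commit): putting BatchNorm submodules back into eval mode after `model.train()` freezes their running statistics, while their affine parameters and the rest of the network still receive gradients.

```python
import torch
import torch.nn as nn

# Minimal sketch of the norm_eval trick used in the revised ResNet.train().
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))

model.train()
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()  # eval mode: running mean/var are no longer updated

before = model[1].running_mean.clone()
model(torch.randn(2, 3, 16, 16)).sum().backward()

assert torch.equal(before, model[1].running_mean)  # stats frozen
assert model[1].weight.grad is not None            # affine params still trainable
```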
 import torch.nn as nn

 norm_cfg = {'BN': nn.BatchNorm2d, 'SyncBN': None, 'GN': nn.GroupNorm}


 def build_norm_layer(cfg, num_features):
+    """
+    cfg should contain:
+        type (str): identifies the norm layer type.
+        layer args: args needed to instantiate a norm layer.
+        frozen (bool): [optional] whether to stop gradient updates of the
+            norm layer; it is helpful to freeze the norm layers in a
+            backbone.
+    """
     assert isinstance(cfg, dict) and 'type' in cfg
     cfg_ = cfg.copy()
     layer_type = cfg_.pop('type')
+    frozen = cfg_.pop('frozen') if 'frozen' in cfg_ else False

     # args name matching
     if layer_type == 'GN':
         assert 'num_groups' in cfg
         cfg_.setdefault('num_channels', num_features)
     elif layer_type == 'BN':
+        cfg_ = dict()  # rewrite necessary info for BN from here
         cfg_.setdefault('num_features', num_features)
         cfg_.setdefault('eps', 1e-5)
@@ -22,4 +31,8 @@ def build_norm_layer(cfg, num_features):
     elif norm_cfg[layer_type] is None:
         raise NotImplementedError

-    return norm_cfg[layer_type](**cfg_)
+    norm = norm_cfg[layer_type](**cfg_)
+    if frozen:
+        for param in norm.parameters():
+            param.requires_grad = False
+    return norm
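A hedged usage sketch of the revised `build_norm_layer` (assuming the function above is in scope): `frozen` is popped before the constructor call, so it never reaches `nn.GroupNorm` or `nn.BatchNorm2d`; it only toggles `requires_grad` on the returned layer.

```python
# Hypothetical usage of the revised build_norm_layer defined above.
gn = build_norm_layer(dict(type='GN', num_groups=32, frozen=True), 256)
assert isinstance(gn, nn.GroupNorm)
assert all(not p.requires_grad for p in gn.parameters())  # gradients stopped

bn = build_norm_layer(dict(type='BN'), 64)  # frozen defaults to False
assert all(p.requires_grad for p in bn.parameters())
```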