Commit dca2d841 authored by ThangVu

revise group norm (2)

parent f64c9561
@@ -12,7 +12,7 @@ model = dict(
         normalize=dict(
             type='GN',
             num_groups=32,
-            eval=False,
+            eval_mode=False,
             frozen=False)),
     neck=dict(
         type='FPN',
......
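The rename above (`eval` to `eval_mode`) avoids shadowing Python's built-in `eval` and reads less ambiguously. A minimal sketch of how the two extra keys travel through the code changed below (this mirrors the pop order visible in the diffs, it is not repo test code): `ResNet.__init__` pops `eval_mode`, `build_norm_layer` pops `frozen`, and only type-specific kwargs reach the norm module.

```python
normalize = dict(type='GN', num_groups=32, eval_mode=False, frozen=False)

norm_eval = normalize.pop('eval_mode')  # consumed by ResNet.__init__
# build_norm_layer later pops 'frozen'; what finally reaches nn.GroupNorm is
# num_groups=32, num_channels=<num_features>, eps=1e-5
print(norm_eval, normalize)
# False {'type': 'GN', 'num_groups': 32, 'frozen': False}
```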
@@ -31,8 +31,7 @@ class BasicBlock(nn.Module):
                  downsample=None,
                  style='pytorch',
                  with_cp=False,
-                 normalize=dict(type='BN'),
-                 frozen=False):
+                 normalize=dict(type='BN')):
         super(BasicBlock, self).__init__()
         self.conv1 = conv3x3(inplanes, planes, stride, dilation)
@@ -51,10 +50,6 @@ class BasicBlock(nn.Module):
         self.dilation = dilation
         assert not with_cp

-        if frozen:
-            for param in self.parameters():
-                param.requires_grad = False
-
     def forward(self, x):
         identity = x
@@ -85,8 +80,7 @@ class Bottleneck(nn.Module):
                  downsample=None,
                  style='pytorch',
                  with_cp=False,
-                 normalize=dict(type='BN'),
-                 frozen=False):
+                 normalize=dict(type='BN')):
         """Bottleneck block for ResNet.

         If style is "pytorch", the stride-two layer is the 3x3 conv layer;
         if it is "caffe", the stride-two layer is the first 1x1 conv layer.
@@ -134,10 +128,6 @@ class Bottleneck(nn.Module):
         self.with_cp = with_cp
         self.normalize = normalize

-        if frozen:
-            for param in self.parameters():
-                param.requires_grad = False
-
     def forward(self, x):

         def _inner_forward(x):
@@ -179,8 +169,7 @@ def make_res_layer(block,
                    dilation=1,
                    style='pytorch',
                    with_cp=False,
-                   normalize=dict(type='BN'),
-                   frozen=False):
+                   normalize=dict(type='BN')):
     downsample = None
     if stride != 1 or inplanes != planes * block.expansion:
         downsample = nn.Sequential(
@@ -203,8 +192,7 @@ def make_res_layer(block,
             downsample,
             style=style,
             with_cp=with_cp,
-            normalize=normalize,
-            frozen=frozen))
+            normalize=normalize))
     inplanes = planes * block.expansion
     for i in range(1, blocks):
         layers.append(
@@ -253,9 +241,10 @@ class ResNet(nn.Module):
                  frozen_stages=-1,
                  normalize=dict(
                      type='BN',
-                     eval=True,
+                     eval_mode=True,
                      frozen=False),
-                 with_cp=False):
+                 with_cp=False,
+                 zero_init_residual=True):
         super(ResNet, self).__init__()
         if depth not in self.arch_settings:
             raise KeyError('invalid depth {} for resnet'.format(depth))
@@ -268,12 +257,13 @@ class ResNet(nn.Module):
         self.out_indices = out_indices
         assert max(out_indices) < num_stages
         self.style = style
-        self.with_cp = with_cp
-        self.is_frozen = [i <= frozen_stages for i in range(num_stages + 1)]
-        assert (isinstance(normalize, dict) and 'eval' in normalize
+        self.frozen_stages = frozen_stages
+        assert (isinstance(normalize, dict) and 'eval_mode' in normalize
                 and 'frozen' in normalize)
-        self.norm_eval = normalize.pop('eval')
+        self.norm_eval = normalize.pop('eval_mode')
         self.normalize = normalize
+        self.with_cp = with_cp
+        self.zero_init_residual = zero_init_residual
         self.block, stage_blocks = self.arch_settings[depth]
         self.stage_blocks = stage_blocks[:num_stages]
         self.inplanes = 64
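One side effect worth noting (an observation about the code above, not something the commit documents): `normalize.pop('eval_mode')` mutates the dict the caller passed in, so a normalize config shared between two backbone constructions loses the key after the first use and would fail the assert the second time.

```python
cfg = dict(type='BN', eval_mode=True, frozen=False)
norm_eval = cfg.pop('eval_mode')  # what ResNet.__init__ effectively does
print(cfg)  # {'type': 'BN', 'frozen': False} -- 'eval_mode' is gone on reuse
```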
@@ -294,13 +284,14 @@ class ResNet(nn.Module):
                 dilation=dilation,
                 style=self.style,
                 with_cp=with_cp,
-                normalize=normalize,
-                frozen=self.is_frozen[i + 1])
+                normalize=normalize)
             self.inplanes = planes * self.block.expansion
             layer_name = 'layer{}'.format(i + 1)
             self.add_module(layer_name, res_layer)
             self.res_layers.append(layer_name)

+        self._freeze_stages()
+
         self.feat_dim = self.block.expansion * 64 * 2**(
             len(self.stage_blocks) - 1)
@@ -313,9 +304,15 @@ class ResNet(nn.Module):
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        if self.is_frozen[0]:
-            for layer in [self.conv1, stem_norm]:
-                for param in layer.parameters():
+        if self.frozen_stages >= 0:
+            for m in [self.conv1, stem_norm]:
+                for param in m.parameters():
                     param.requires_grad = False

+    def _freeze_stages(self):
+        for i in range(1, self.frozen_stages + 1):
+            m = getattr(self, 'layer{}'.format(i))
+            for param in m.parameters():
+                param.requires_grad = False
+
     def init_weights(self, pretrained=None):
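The per-block `frozen` flags are gone; freezing is now driven by a single `frozen_stages` argument, where stage 0 is the stem (`conv1` plus its norm) and stages 1..N are `layer1`..`layerN`. A self-contained sketch of the same pattern (`TinyBackbone` is hypothetical, not repo code):

```python
import torch.nn as nn

class TinyBackbone(nn.Module):

    def __init__(self, frozen_stages=1):
        super(TinyBackbone, self).__init__()
        self.frozen_stages = frozen_stages
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)  # stands in for the stem
        for i in range(1, 5):  # stands in for layer1..layer4
            self.add_module('layer{}'.format(i), nn.Conv2d(8, 8, 3, padding=1))
        self._freeze_stages()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:  # stage 0: the stem
            for param in self.conv1.parameters():
                param.requires_grad = False
        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            for param in m.parameters():
                param.requires_grad = False

net = TinyBackbone(frozen_stages=1)
print([n for n, p in net.named_parameters() if p.requires_grad])
# only layer2..layer4 parameters remain trainable
```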
@@ -326,13 +323,13 @@ class ResNet(nn.Module):
             for m in self.modules():
                 if isinstance(m, nn.Conv2d):
                     kaiming_init(m)
-                elif (isinstance(m, nn.BatchNorm2d)
-                      or isinstance(m, nn.GroupNorm)):
+                elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                     constant_init(m, 1)

             # zero init for the last norm layer, see https://arxiv.org/abs/1706.02677
-            for m in self.modules():
-                if isinstance(m, Bottleneck) or isinstance(m, BasicBlock):
-                    last_norm = getattr(m, m.norm_names[-1])
-                    constant_init(last_norm, 0)
+            if self.zero_init_residual:
+                for m in self.modules():
+                    if isinstance(m, (Bottleneck, BasicBlock)):
+                        last_norm = getattr(m, m.norm_names[-1])
+                        constant_init(last_norm, 0)
         else:
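The new `zero_init_residual` switch implements the trick referenced in the comment (Goyal et al., https://arxiv.org/abs/1706.02677): zeroing the scale of the last norm layer in each residual block makes the residual branch output zero at init, so every block starts out as an identity mapping. A standalone illustration in plain PyTorch (mirroring what `constant_init(last_norm, 0)` does to a single norm layer):

```python
import torch
import torch.nn as nn

norm = nn.BatchNorm2d(8)
nn.init.constant_(norm.weight, 0.)  # scale (gamma) to zero
nn.init.constant_(norm.bias, 0.)    # shift (beta) to zero

x = torch.randn(2, 8, 4, 4)
print(norm(x).abs().max().item())
# 0.0 -- the residual branch contributes nothing until training moves gamma
```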
@@ -357,7 +354,7 @@ class ResNet(nn.Module):
     def train(self, mode=True):
         super(ResNet, self).train(mode)
         if mode and self.norm_eval:
-            for mod in self.modules():
+            for m in self.modules():
                 # trick: eval() only has an effect on BatchNorm layers
-                if isinstance(self, nn.BatchNorm2d):
-                    mod.eval()
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eval()
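With `norm_eval=True` (now `eval_mode=True` in the config), `train()` flips the whole model into training mode and then puts every `BatchNorm2d` back into eval mode, so frozen running statistics are used even while training. The old loop tested `self` instead of the module, so it never fired; the fix above is what makes the trick actually work. A minimal standalone version of the pattern (hypothetical class, not repo code):

```python
import torch.nn as nn

class NormEvalNet(nn.Module):

    def __init__(self):
        super(NormEvalNet, self).__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)
        self.norm_eval = True

    def train(self, mode=True):
        super(NormEvalNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()  # keep BN in eval mode during training
        return self

net = NormEvalNet().train()
print(net.training, net.bn.training)  # True False
```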
@@ -15,21 +15,23 @@ def build_norm_layer(cfg, num_features):
     """
     assert isinstance(cfg, dict) and 'type' in cfg
     cfg_ = cfg.copy()

     layer_type = cfg_.pop('type')
-    frozen = cfg_.pop('frozen') if 'frozen' in cfg_ else False
-    if layer_type not in norm_cfg:
-        raise KeyError('Unrecognized norm type {}'.format(layer_type))
-    elif norm_cfg[layer_type] is None:
-        raise NotImplementedError
+    frozen = cfg_.pop('frozen', False)

     # args name matching
-    if layer_type == 'GN':
+    if layer_type in ['GN']:
         assert 'num_groups' in cfg
         cfg_.setdefault('num_channels', num_features)
-    elif layer_type == 'BN':
+    elif layer_type in ['BN']:
         cfg_.setdefault('num_features', num_features)
-    cfg_.setdefault('eps', 1e-5)

+    if layer_type not in norm_cfg:
+        raise KeyError('Unrecognized norm type {}'.format(layer_type))
+    elif norm_cfg[layer_type] is None:
+        raise NotImplementedError
+    else:
+        cfg_.setdefault('eps', 1e-5)
+
     norm = norm_cfg[layer_type](**cfg_)
     if frozen:
......
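Putting the reordered `build_norm_layer` together, a runnable sketch of the construction path after this commit (the `norm_cfg` mapping below is an assumption covering only the types visible here; in the real module it is a registry that may list unsupported types as `None`):

```python
import torch.nn as nn

norm_cfg = {'BN': nn.BatchNorm2d, 'GN': nn.GroupNorm, 'SyncBN': None}

def build_norm_layer_sketch(cfg, num_features):
    cfg_ = cfg.copy()
    layer_type = cfg_.pop('type')
    frozen = cfg_.pop('frozen', False)
    # args name matching
    if layer_type in ['GN']:
        assert 'num_groups' in cfg
        cfg_.setdefault('num_channels', num_features)
    elif layer_type in ['BN']:
        cfg_.setdefault('num_features', num_features)
    if layer_type not in norm_cfg:
        raise KeyError('Unrecognized norm type {}'.format(layer_type))
    elif norm_cfg[layer_type] is None:  # registered but unsupported, e.g. SyncBN
        raise NotImplementedError
    cfg_.setdefault('eps', 1e-5)
    norm = norm_cfg[layer_type](**cfg_)
    if frozen:
        for param in norm.parameters():
            param.requires_grad = False
    return norm

print(build_norm_layer_sketch(dict(type='GN', num_groups=32), 256))
# GroupNorm(32, 256, eps=1e-05, affine=True)
```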