"vscode:/vscode.git/clone" did not exist on "929870d3ad5af4ed39567f25c3e7433d84924cd7"
Commit dca2d841 authored by ThangVu

revise group norm (2)

parent f64c9561
@@ -12,7 +12,7 @@ model = dict(
         normalize=dict(
             type='GN',
             num_groups=32,
-            eval=False,
+            eval_mode=False,
             frozen=False)),
     neck=dict(
         type='FPN',
...
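The config change above only renames the `eval` key to `eval_mode` (plausibly to avoid confusion with Python's built-in `eval`; the commit message does not say). A minimal sketch of a backbone config using the revised keys; everything outside the `normalize` dict (the `backbone` key, `type='ResNet'`, `depth=50`) is illustrative rather than taken from this diff:

    # hypothetical config fragment; only the normalize dict mirrors this commit
    model = dict(
        backbone=dict(
            type='ResNet',
            depth=50,
            normalize=dict(
                type='GN',
                num_groups=32,
                eval_mode=False,  # renamed from 'eval' in this commit
                frozen=False)),
        neck=dict(
            type='FPN'))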
@@ -31,8 +31,7 @@ class BasicBlock(nn.Module):
                  downsample=None,
                  style='pytorch',
                  with_cp=False,
-                 normalize=dict(type='BN'),
-                 frozen=False):
+                 normalize=dict(type='BN')):
         super(BasicBlock, self).__init__()
         self.conv1 = conv3x3(inplanes, planes, stride, dilation)
@@ -51,10 +50,6 @@ class BasicBlock(nn.Module):
         self.dilation = dilation
         assert not with_cp

-        if frozen:
-            for param in self.parameters():
-                param.requires_grad = False
-
     def forward(self, x):
         identity = x
@@ -85,8 +80,7 @@ class Bottleneck(nn.Module):
                  downsample=None,
                  style='pytorch',
                  with_cp=False,
-                 normalize=dict(type='BN'),
-                 frozen=False):
+                 normalize=dict(type='BN')):
         """Bottleneck block for ResNet.
         If style is "pytorch", the stride-two layer is the 3x3 conv layer,
         if it is "caffe", the stride-two layer is the first 1x1 conv layer.
@@ -134,10 +128,6 @@ class Bottleneck(nn.Module):
         self.with_cp = with_cp
         self.normalize = normalize

-        if frozen:
-            for param in self.parameters():
-                param.requires_grad = False
-
     def forward(self, x):

         def _inner_forward(x):
@@ -179,8 +169,7 @@ def make_res_layer(block,
                    dilation=1,
                    style='pytorch',
                    with_cp=False,
-                   normalize=dict(type='BN'),
-                   frozen=False):
+                   normalize=dict(type='BN')):
     downsample = None
     if stride != 1 or inplanes != planes * block.expansion:
         downsample = nn.Sequential(
@@ -203,8 +192,7 @@ def make_res_layer(block,
             downsample,
             style=style,
             with_cp=with_cp,
-            normalize=normalize,
-            frozen=frozen))
+            normalize=normalize))
     inplanes = planes * block.expansion
     for i in range(1, blocks):
         layers.append(
@@ -253,9 +241,10 @@ class ResNet(nn.Module):
                  frozen_stages=-1,
                  normalize=dict(
                      type='BN',
-                     eval=True,
+                     eval_mode=True,
                      frozen=False),
-                 with_cp=False):
+                 with_cp=False,
+                 zero_init_residual=True):
         super(ResNet, self).__init__()
         if depth not in self.arch_settings:
             raise KeyError('invalid depth {} for resnet'.format(depth))
@@ -268,12 +257,13 @@ class ResNet(nn.Module):
         self.out_indices = out_indices
         assert max(out_indices) < num_stages
         self.style = style
-        self.with_cp = with_cp
-        self.is_frozen = [i <= frozen_stages for i in range(num_stages + 1)]
-        assert (isinstance(normalize, dict) and 'eval' in normalize
+        self.frozen_stages = frozen_stages
+        assert (isinstance(normalize, dict) and 'eval_mode' in normalize
                 and 'frozen' in normalize)
-        self.norm_eval = normalize.pop('eval')
+        self.norm_eval = normalize.pop('eval_mode')
         self.normalize = normalize
+        self.with_cp = with_cp
+        self.zero_init_residual = zero_init_residual
         self.block, stage_blocks = self.arch_settings[depth]
         self.stage_blocks = stage_blocks[:num_stages]
         self.inplanes = 64
@@ -294,13 +284,14 @@ class ResNet(nn.Module):
                 dilation=dilation,
                 style=self.style,
                 with_cp=with_cp,
-                normalize=normalize,
-                frozen=self.is_frozen[i + 1])
+                normalize=normalize)
             self.inplanes = planes * self.block.expansion
             layer_name = 'layer{}'.format(i + 1)
             self.add_module(layer_name, res_layer)
             self.res_layers.append(layer_name)

+        self._freeze_stages()
+
         self.feat_dim = self.block.expansion * 64 * 2**(
             len(self.stage_blocks) - 1)
@@ -313,11 +304,17 @@ class ResNet(nn.Module):
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        if self.is_frozen[0]:
-            for layer in [self.conv1, stem_norm]:
-                for param in layer.parameters():
+        if self.frozen_stages >= 0:
+            for m in [self.conv1, stem_norm]:
+                for param in m.parameters():
                     param.requires_grad = False

+    def _freeze_stages(self):
+        for i in range(1, self.frozen_stages + 1):
+            m = getattr(self, 'layer{}'.format(i))
+            for param in m.parameters():
+                param.requires_grad = False
+
     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
             logger = logging.getLogger()
@@ -326,15 +323,15 @@ class ResNet(nn.Module):
             for m in self.modules():
                 if isinstance(m, nn.Conv2d):
                     kaiming_init(m)
-                elif (isinstance(m, nn.BatchNorm2d)
-                      or isinstance(m, nn.GroupNorm)):
+                elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                     constant_init(m, 1)
             # zero init for the last norm layer, https://arxiv.org/abs/1706.02677
-            for m in self.modules():
-                if isinstance(m, Bottleneck) or isinstance(m, BasicBlock):
-                    last_norm = getattr(m, m.norm_names[-1])
-                    constant_init(last_norm, 0)
+            if self.zero_init_residual:
+                for m in self.modules():
+                    if isinstance(m, (Bottleneck, BasicBlock)):
+                        last_norm = getattr(m, m.norm_names[-1])
+                        constant_init(last_norm, 0)
         else:
             raise TypeError('pretrained must be a str or None')
@@ -357,7 +354,7 @@ class ResNet(nn.Module):
     def train(self, mode=True):
         super(ResNet, self).train(mode)
         if mode and self.norm_eval:
-            for mod in self.modules():
+            for m in self.modules():
                 # trick: eval() has an effect on BatchNorm only
-                if isinstance(self, nn.BatchNorm2d):
-                    mod.eval()
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eval()
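Taken together, the resnet.py changes replace the per-block frozen flag (and the is_frozen list) with a single frozen_stages integer applied in one place, _freeze_stages(), and gate the zero-initialization of each block's last norm layer behind the new zero_init_residual switch. A self-contained sketch of the same freezing pattern on a toy backbone; the module names and layer shapes are hypothetical, only the _freeze_stages logic mirrors the diff:

    import torch.nn as nn

    class TinyBackbone(nn.Module):
        """Toy two-stage backbone illustrating stage freezing."""

        def __init__(self, frozen_stages=-1):
            super(TinyBackbone, self).__init__()
            self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=3)  # stem
            self.layer1 = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1))
            self.layer2 = nn.Sequential(nn.Conv2d(64, 128, 3, padding=1))
            self.frozen_stages = frozen_stages
            # the diff freezes the stem where the stem is built; for brevity
            # this sketch folds that step into _freeze_stages as well
            self._freeze_stages()

        def _freeze_stages(self):
            if self.frozen_stages >= 0:
                for param in self.conv1.parameters():
                    param.requires_grad = False
            for i in range(1, self.frozen_stages + 1):
                m = getattr(self, 'layer{}'.format(i))
                for param in m.parameters():
                    param.requires_grad = False

    backbone = TinyBackbone(frozen_stages=1)  # freezes the stem and layer1
    assert not any(p.requires_grad for p in backbone.layer1.parameters())
    assert all(p.requires_grad for p in backbone.layer2.parameters())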
@@ -15,21 +15,23 @@ def build_norm_layer(cfg, num_features):
     """
     assert isinstance(cfg, dict) and 'type' in cfg
     cfg_ = cfg.copy()
     layer_type = cfg_.pop('type')
-    frozen = cfg_.pop('frozen') if 'frozen' in cfg_ else False
+    if layer_type not in norm_cfg:
+        raise KeyError('Unrecognized norm type {}'.format(layer_type))
+    elif norm_cfg[layer_type] is None:
+        raise NotImplementedError
+
+    frozen = cfg_.pop('frozen', False)
     # args name matching
-    if layer_type == 'GN':
+    if layer_type in ['GN']:
         assert 'num_groups' in cfg
         cfg_.setdefault('num_channels', num_features)
-    elif layer_type == 'BN':
+    elif layer_type in ['BN']:
         cfg_.setdefault('num_features', num_features)
-        cfg_.setdefault('eps', 1e-5)
-    if layer_type not in norm_cfg:
-        raise KeyError('Unrecognized norm type {}'.format(layer_type))
-    elif norm_cfg[layer_type] is None:
-        raise NotImplementedError
+    else:
+        raise NotImplementedError
+
+    cfg_.setdefault('eps', 1e-5)

     norm = norm_cfg[layer_type](**cfg_)
     if frozen:
...
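In build_norm_layer, the unknown-type check now runs before any keys are popped, the frozen flag uses dict.pop's default instead of a conditional expression, and the eps default is set once for every norm type rather than only for BN. A runnable sketch of the revised control flow, with norm_cfg reduced to the two types visible in this diff; the body of the frozen branch is truncated above, so disabling gradients there is an assumption:

    import torch.nn as nn

    norm_cfg = {'BN': nn.BatchNorm2d, 'GN': nn.GroupNorm}

    def build_norm_layer(cfg, num_features):
        assert isinstance(cfg, dict) and 'type' in cfg
        cfg_ = cfg.copy()
        layer_type = cfg_.pop('type')
        if layer_type not in norm_cfg:
            raise KeyError('Unrecognized norm type {}'.format(layer_type))
        frozen = cfg_.pop('frozen', False)
        # args name matching: GN takes num_channels, BN takes num_features
        if layer_type in ['GN']:
            assert 'num_groups' in cfg
            cfg_.setdefault('num_channels', num_features)
        elif layer_type in ['BN']:
            cfg_.setdefault('num_features', num_features)
        else:
            raise NotImplementedError
        cfg_.setdefault('eps', 1e-5)
        norm = norm_cfg[layer_type](**cfg_)
        if frozen:  # assumed behavior: freeze the affine parameters
            for param in norm.parameters():
                param.requires_grad = False
        return norm

    gn = build_norm_layer(dict(type='GN', num_groups=32), 256)  # GroupNorm(32, 256)
    bn = build_norm_layer(dict(type='BN', frozen=True), 64)     # frozen BatchNorm2d(64)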