Commit 574a920a authored by ThangVu

revise group norm (4)

parent 3fdd041c
@@ -35,9 +35,8 @@ class BasicBlock(nn.Module):
         super(BasicBlock, self).__init__()
         self.conv1 = conv3x3(inplanes, planes, stride, dilation)
-        # build_norm_layer return: (norm_name, norm_layer)
-        self.norm1, norm1 = build_norm_layer(normalize, planes, postfix=1)
-        self.norm2, norm2 = build_norm_layer(normalize, planes, postfix=2)
+        self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
+        self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)
         self.add_module(self.norm1, norm1)
         self.add_module(self.norm2, norm2)
@@ -48,15 +47,23 @@ class BasicBlock(nn.Module):
         self.dilation = dilation
         assert not with_cp

+    @property
+    def norm1(self):
+        return getattr(self, self.norm1_name)
+
+    @property
+    def norm2(self):
+        return getattr(self, self.norm2_name)
+
     def forward(self, x):
         identity = x

         out = self.conv1(x)
-        out = getattr(self, self.norm1)(out)
+        out = self.norm1(out)
         out = self.relu(out)

         out = self.conv2(out)
-        out = getattr(self, self.norm2)(out)
+        out = self.norm2(out)

         if self.downsample is not None:
             identity = self.downsample(x)
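The pattern this commit settles on, visible in the two hunks above, is to keep only the layer name (e.g. self.norm1_name) as an attribute, register the actual module under that name with add_module, and expose it through a read-only property so call sites can simply write self.norm1(out). Below is a minimal, self-contained sketch of the idea; the MiniBlock class and the simplified build_norm_layer are illustrative stand-ins, not code from this repository:

    import torch
    import torch.nn as nn


    def build_norm_layer(cfg, num_features, postfix=''):
        # Simplified stand-in that mimics the documented contract:
        # return (name, layer), where name is an abbreviation plus postfix.
        abbr = {'BN': 'bn', 'GN': 'gn'}[cfg['type']]
        if cfg['type'] == 'BN':
            layer = nn.BatchNorm2d(num_features)
        else:
            layer = nn.GroupNorm(cfg['num_groups'], num_features)
        return abbr + str(postfix), layer


    class MiniBlock(nn.Module):
        def __init__(self, planes, normalize=dict(type='BN')):
            super(MiniBlock, self).__init__()
            self.conv1 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
            # Keep only the name as an attribute and register the module under
            # it, so the state_dict key reflects the norm type ('bn1' or 'gn1').
            self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
            self.add_module(self.norm1_name, norm1)

        @property
        def norm1(self):
            # Resolve the dynamically named submodule on access.
            return getattr(self, self.norm1_name)

        def forward(self, x):
            return self.norm1(self.conv1(x))


    block = MiniBlock(8, normalize=dict(type='GN', num_groups=4))
    print(list(block.state_dict().keys()))  # ['conv1.weight', 'gn1.weight', 'gn1.bias']
    out = block(torch.randn(2, 8, 16, 16))

The property keeps call sites identical no matter which norm is configured, while the checkpoint keys ('bn1.*' vs 'gn1.*') still record which norm was actually used.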
@@ -108,14 +115,14 @@ class Bottleneck(nn.Module):
             dilation=dilation,
             bias=False)
-        # build_norm_layer return: (norm_name, norm_layer)
-        self.norm1, norm1 = build_norm_layer(normalize, planes, postfix=1)
-        self.norm2, norm2 = build_norm_layer(normalize, planes, postfix=2)
-        self.norm3, norm3 = build_norm_layer(normalize, planes*self.expansion,
-                                             postfix=3)
-        self.add_module(self.norm1, norm1)
-        self.add_module(self.norm2, norm2)
-        self.add_module(self.norm3, norm3)
+        self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
+        self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)
+        self.norm3_name, norm3 = build_norm_layer(normalize,
+                                                  planes * self.expansion,
+                                                  postfix=3)
+        self.add_module(self.norm1_name, norm1)
+        self.add_module(self.norm2_name, norm2)
+        self.add_module(self.norm3_name, norm3)
         self.conv3 = nn.Conv2d(
             planes, planes * self.expansion, kernel_size=1, bias=False)
@@ -126,21 +133,33 @@ class Bottleneck(nn.Module):
         self.with_cp = with_cp
         self.normalize = normalize

+    @property
+    def norm1(self):
+        return getattr(self, self.norm1_name)
+
+    @property
+    def norm2(self):
+        return getattr(self, self.norm2_name)
+
+    @property
+    def norm3(self):
+        return getattr(self, self.norm3_name)
+
     def forward(self, x):

         def _inner_forward(x):
             identity = x

             out = self.conv1(x)
-            out = getattr(self, self.norm1)(out)
+            out = self.norm1(out)
             out = self.relu(out)

             out = self.conv2(out)
-            out = getattr(self, self.norm2)(out)
+            out = self.norm2(out)
             out = self.relu(out)

             out = self.conv3(out)
-            out = getattr(self, self.norm3)(out)
+            out = self.norm3(out)

             if self.downsample is not None:
                 identity = self.downsample(x)
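The Bottleneck forward is wrapped in a nested _inner_forward so that, when with_cp is set, the whole residual branch can be recomputed during the backward pass instead of keeping its activations in memory. The hunk only shows the inner function; the sketch below illustrates how such a wrapper is typically driven with torch.utils.checkpoint (a toy module with assumed wiring, not the exact lines of this file):

    import torch
    import torch.nn as nn
    import torch.utils.checkpoint as cp


    class TinyResidual(nn.Module):
        """Toy residual unit with optional activation checkpointing."""

        def __init__(self, channels, with_cp=False):
            super(TinyResidual, self).__init__()
            self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
            self.norm1 = nn.BatchNorm2d(channels)
            self.relu = nn.ReLU(inplace=True)
            self.with_cp = with_cp

        def forward(self, x):

            def _inner_forward(x):
                identity = x
                out = self.relu(self.norm1(self.conv1(x)))
                return out + identity

            if self.with_cp and x.requires_grad:
                # Recompute the branch during backward to save activation memory.
                out = cp.checkpoint(_inner_forward, x)
            else:
                out = _inner_forward(x)
            return self.relu(out)


    x = torch.randn(2, 8, 16, 16, requires_grad=True)
    TinyResidual(8, with_cp=True)(x).sum().backward()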
@@ -293,17 +312,21 @@ class ResNet(nn.Module):
         self.feat_dim = self.block.expansion * 64 * 2**(
             len(self.stage_blocks) - 1)

+    @property
+    def norm1(self):
+        return getattr(self, self.norm1_name)
+
     def _make_stem_layer(self):
         self.conv1 = nn.Conv2d(
             3, 64, kernel_size=7, stride=2, padding=3, bias=False)
-        self.stem_norm, stem_norm = build_norm_layer(self.normalize,
-                                                     64, postfix=1)
-        self.add_module(self.stem_norm, stem_norm)
+        self.norm1_name, norm1 = build_norm_layer(self.normalize,
+                                                  64, postfix=1)
+        self.add_module(self.norm1_name, norm1)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

         if self.frozen_stages >= 0:
-            for m in [self.conv1, stem_norm]:
+            for m in [self.conv1, self.norm1]:
                 for param in m.parameters():
                     param.requires_grad = False
@@ -327,15 +350,16 @@ class ResNet(nn.Module):
             # zero init for last norm layer https://arxiv.org/abs/1706.02677
             if self.zero_init_residual:
                 for m in self.modules():
-                    if isinstance(m, (Bottleneck, BasicBlock)):
-                        last_norm = getattr(m, m.norm_names[-1])
-                        constant_init(last_norm, 0)
+                    if isinstance(m, Bottleneck):
+                        constant_init(m.norm3, 0)
+                    elif isinstance(m, BasicBlock):
+                        constant_init(m.norm2, 0)
         else:
             raise TypeError('pretrained must be a str or None')

     def forward(self, x):
         x = self.conv1(x)
-        x = getattr(self, self.stem_norm)(x)
+        x = self.norm1(x)
         x = self.relu(x)
         x = self.maxpool(x)
         outs = []
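With the new properties, zero_init_residual can address the last norm of each block directly: norm3 for Bottleneck and norm2 for BasicBlock get a zero scale, so every residual branch starts as a no-op and the block behaves like an identity mapping at initialization (the trick from https://arxiv.org/abs/1706.02677). A small sketch of what constant_init(layer, 0) amounts to for a norm layer; the helper below is an assumed stand-in written with plain nn.init, not the mmcv function itself:

    import torch.nn as nn


    def constant_init(module, val, bias=0):
        # Assumed behaviour of the helper used above: fill the affine weight
        # with `val` and the bias with `bias`.
        nn.init.constant_(module.weight, val)
        if getattr(module, 'bias', None) is not None:
            nn.init.constant_(module.bias, bias)


    # Zeroing the scale of the last norm in a residual branch makes the branch
    # output zeros at initialization, so the block starts as an identity map.
    norm3 = nn.BatchNorm2d(64)
    constant_init(norm3, 0)
    print(norm3.weight.sum().item(), norm3.bias.sum().item())  # 0.0 0.0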
...
@@ -53,8 +53,8 @@ class ConvModule(nn.Module):
         if self.with_norm:
             norm_channels = out_channels if self.activate_last else in_channels
-            self.norm, norm = build_norm_layer(normalize, norm_channels)
-            self.add_module(self.norm, norm)
+            self.norm_name, norm = build_norm_layer(normalize, norm_channels)
+            self.add_module(self.norm_name, norm)

         if self.with_activatation:
             assert activation in ['relu'], 'Only ReLU supported.'
@@ -64,6 +64,10 @@ class ConvModule(nn.Module):
         # Default using msra init
         self.init_weights()

+    @property
+    def norm(self):
+        return getattr(self, self.norm_name)
+
     def init_weights(self):
         nonlinearity = 'relu' if self.activation is None else self.activation
         kaiming_init(self.conv, nonlinearity=nonlinearity)
@@ -74,12 +78,12 @@ class ConvModule(nn.Module):
         if self.activate_last:
             x = self.conv(x)
             if norm and self.with_norm:
-                x = getattr(self, self.norm)(x)
+                x = self.norm(x)
             if activate and self.with_activatation:
                 x = self.activate(x)
         else:
             if norm and self.with_norm:
-                x = getattr(self, self.norm)(x)
+                x = self.norm(x)
             if activate and self.with_activatation:
                 x = self.activate(x)
             x = self.conv(x)
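ConvModule gets the same name-plus-property treatment, and its forward keeps the activate_last switch: conv, then norm, then activation when the flag is true, and norm, activation, conv otherwise, which is also why norm_channels is out_channels in one case and in_channels in the other. A toy block showing just that ordering logic (illustrative only, not the ConvModule source):

    import torch
    import torch.nn as nn


    class ToyConvBlock(nn.Module):
        """Toy version of the conv/norm/activation ordering in ConvModule."""

        def __init__(self, in_ch, out_ch, activate_last=True):
            super(ToyConvBlock, self).__init__()
            self.activate_last = activate_last
            self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
            # Norm acts on the conv output when activation comes last,
            # otherwise it normalizes the conv input.
            norm_channels = out_ch if activate_last else in_ch
            self.norm = nn.BatchNorm2d(norm_channels)
            self.activate = nn.ReLU(inplace=True)

        def forward(self, x):
            if self.activate_last:
                return self.activate(self.norm(self.conv(x)))
            return self.conv(self.activate(self.norm(x)))


    y1 = ToyConvBlock(8, 16, activate_last=True)(torch.randn(2, 8, 32, 32))
    y2 = ToyConvBlock(8, 16, activate_last=False)(torch.randn(2, 8, 32, 32))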
...
@@ -11,13 +11,22 @@ norm_cfg = {
 def build_norm_layer(cfg, num_features, postfix=''):
-    """
-    cfg should contain:
-        type (str): identify norm layer type.
-        layer args: args needed to instantiate a norm layer.
-        frozen (bool): [optional] whether stop gradient updates
-            of norm layer, it is helpful to set frozen mode
-            in backbone's norms.
+    """ Build normalization layer
+
+    Args:
+        cfg (dict): cfg should contain:
+            type (str): identify norm layer type.
+            layer args: args needed to instantiate a norm layer.
+            frozen (bool): [optional] whether stop gradient updates
+                of norm layer, it is helpful to set frozen mode
+                in backbone's norms.
+        num_features (int): number of channels from input
+        postfix (int, str): appended into norm abbreation to
+            create named layer.
+
+    Returns:
+        name (str): abbreation + postfix
+        layer (nn.Module): created norm layer
     """
     assert isinstance(cfg, dict) and 'type' in cfg
     cfg_ = cfg.copy()
...
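The expanded docstring makes the contract explicit: build_norm_layer takes a config dict, a channel count, and a postfix, and returns a (name, layer) pair whose name is the norm abbreviation plus the postfix. A hedged usage sketch; the import path and the 'bn'/'gn' abbreviations are assumptions based on how the returned names are used elsewhere in this diff, not something stated in this hunk:

    # Import path assumed for illustration.
    from mmdet.models.utils.norm import build_norm_layer

    name, layer = build_norm_layer(dict(type='BN'), 64, postfix=1)
    # expected: name == 'bn1', layer is an nn.BatchNorm2d over 64 channels

    name, layer = build_norm_layer(dict(type='GN', num_groups=32), 256, postfix=2)
    # expected: name == 'gn2', layer is an nn.GroupNorm(32, 256)

    # The optional frozen key is documented to stop gradient updates of the
    # layer, which is what backbone norms restored from a pretrained
    # checkpoint usually want.
    name, layer = build_norm_layer(dict(type='BN', frozen=True), 64, postfix=1)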