Commit 574a920a authored by ThangVu

revise group norm (4)

parent 3fdd041c
@@ -35,9 +35,8 @@ class BasicBlock(nn.Module):
         super(BasicBlock, self).__init__()
         self.conv1 = conv3x3(inplanes, planes, stride, dilation)
-        # build_norm_layer return: (norm_name, norm_layer)
-        self.norm1, norm1 = build_norm_layer(normalize, planes, postfix=1)
-        self.norm2, norm2 = build_norm_layer(normalize, planes, postfix=2)
-        self.add_module(self.norm1, norm1)
-        self.add_module(self.norm2, norm2)
+        self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
+        self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)
+        self.add_module(self.norm1_name, norm1)
+        self.add_module(self.norm2_name, norm2)
@@ -48,15 +47,23 @@ class BasicBlock(nn.Module):
         self.dilation = dilation
         assert not with_cp

+    @property
+    def norm1(self):
+        return getattr(self, self.norm1_name)
+
+    @property
+    def norm2(self):
+        return getattr(self, self.norm2_name)
+
     def forward(self, x):
         identity = x

         out = self.conv1(x)
-        out = getattr(self, self.norm1)(out)
+        out = self.norm1(out)
         out = self.relu(out)

         out = self.conv2(out)
-        out = getattr(self, self.norm2)(out)
+        out = self.norm2(out)

         if self.downsample is not None:
             identity = self.downsample(x)
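Taken together, the two hunks above are the whole pattern this commit rolls out: build_norm_layer returns a (name, layer) pair, the layer is registered under that type-specific name (so checkpoint keys read 'bn1.weight' or 'gn1.weight' depending on the config), and a property hides the dynamic attribute name from forward(). A minimal self-contained sketch of the pattern — the build_norm_layer stub here is a hypothetical stand-in, not this repo's implementation:

```python
import torch
import torch.nn as nn


def build_norm_layer(cfg, num_features, postfix=''):
    # Hypothetical stand-in mimicking only the (name, layer) contract
    # visible in this diff; the real implementation is in the last file below.
    if cfg['type'] == 'GN':
        return 'gn{}'.format(postfix), nn.GroupNorm(cfg['num_groups'], num_features)
    return 'bn{}'.format(postfix), nn.BatchNorm2d(num_features)


class Block(nn.Module):
    def __init__(self, planes, normalize):
        super(Block, self).__init__()
        self.conv1 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        # Store only the *name*; register the layer under it so state_dict
        # keys stay type-specific ('bn1.weight' vs. 'gn1.weight').
        self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
        self.add_module(self.norm1_name, norm1)

    @property
    def norm1(self):
        # Resolve the dynamic attribute name behind a stable interface.
        return getattr(self, self.norm1_name)

    def forward(self, x):
        return self.norm1(self.conv1(x))


blk = Block(16, dict(type='GN', num_groups=4))
print(blk.norm1)                             # GroupNorm(4, 16, ...)
print(blk(torch.randn(1, 16, 8, 8)).shape)   # torch.Size([1, 16, 8, 8])
```

Registering under a dynamic name instead of a fixed attribute also means a BatchNorm checkpoint cannot be silently loaded into a GroupNorm model: the state_dict keys differ.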
@@ -108,14 +115,14 @@ class Bottleneck(nn.Module):
             dilation=dilation,
             bias=False)
-        # build_norm_layer return: (norm_name, norm_layer)
-        self.norm1, norm1 = build_norm_layer(normalize, planes, postfix=1)
-        self.norm2, norm2 = build_norm_layer(normalize, planes, postfix=2)
-        self.norm3, norm3 = build_norm_layer(normalize, planes*self.expansion,
-                                             postfix=3)
-        self.add_module(self.norm1, norm1)
-        self.add_module(self.norm2, norm2)
-        self.add_module(self.norm3, norm3)
+        self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
+        self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)
+        self.norm3_name, norm3 = build_norm_layer(normalize,
+                                                  planes * self.expansion,
+                                                  postfix=3)
+        self.add_module(self.norm1_name, norm1)
+        self.add_module(self.norm2_name, norm2)
+        self.add_module(self.norm3_name, norm3)
         self.conv3 = nn.Conv2d(
             planes, planes * self.expansion, kernel_size=1, bias=False)
@@ -126,21 +133,33 @@ class Bottleneck(nn.Module):
         self.with_cp = with_cp
         self.normalize = normalize

+    @property
+    def norm1(self):
+        return getattr(self, self.norm1_name)
+
+    @property
+    def norm2(self):
+        return getattr(self, self.norm2_name)
+
+    @property
+    def norm3(self):
+        return getattr(self, self.norm3_name)
+
     def forward(self, x):

         def _inner_forward(x):
             identity = x

             out = self.conv1(x)
-            out = getattr(self, self.norm1)(out)
+            out = self.norm1(out)
             out = self.relu(out)

             out = self.conv2(out)
-            out = getattr(self, self.norm2)(out)
+            out = self.norm2(out)
             out = self.relu(out)

             out = self.conv3(out)
-            out = getattr(self, self.norm3)(out)
+            out = self.norm3(out)

             if self.downsample is not None:
                 identity = self.downsample(x)
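The _inner_forward closure above exists so the conv/norm chain can optionally run under gradient checkpointing (the with_cp flag stored a few lines earlier). The dispatch itself sits outside the shown context; below is a sketch of the usual shape of that pattern with torch.utils.checkpoint, on a deliberately reduced hypothetical block:

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp


class TinyBlock(nn.Module):
    # Hypothetical reduced block; only the with_cp dispatch matters here.
    def __init__(self, planes, with_cp=False):
        super(TinyBlock, self).__init__()
        self.conv1 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.norm1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.with_cp = with_cp

    def forward(self, x):
        def _inner_forward(x):
            identity = x
            out = self.relu(self.norm1(self.conv1(x)))
            return out + identity

        if self.with_cp and x.requires_grad:
            # Recompute activations during backward instead of storing
            # them, trading compute for memory.
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)
        return self.relu(out)


x = torch.randn(2, 8, 16, 16, requires_grad=True)
y = TinyBlock(8, with_cp=True)(x)
y.sum().backward()
```

One caveat worth remembering: layers with running statistics (e.g. BatchNorm) execute twice under checkpointing, once in the forward pass and once in the recompute.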
@@ -293,17 +312,21 @@ class ResNet(nn.Module):
         self.feat_dim = self.block.expansion * 64 * 2**(
             len(self.stage_blocks) - 1)

+    @property
+    def norm1(self):
+        return getattr(self, self.norm1_name)
+
     def _make_stem_layer(self):
         self.conv1 = nn.Conv2d(
             3, 64, kernel_size=7, stride=2, padding=3, bias=False)
-        self.stem_norm, stem_norm = build_norm_layer(self.normalize,
-                                                     64, postfix=1)
-        self.add_module(self.stem_norm, stem_norm)
+        self.norm1_name, norm1 = build_norm_layer(self.normalize,
+                                                  64, postfix=1)
+        self.add_module(self.norm1_name, norm1)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

         if self.frozen_stages >= 0:
-            for m in [self.conv1, stem_norm]:
+            for m in [self.conv1, self.norm1]:
                 for param in m.parameters():
                     param.requires_grad = False
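A caveat on the freezing loop above: param.requires_grad = False stops gradient updates to the affine parameters, but a BatchNorm layer left in train mode still updates its running statistics on every forward pass. Backbones that freeze stages typically also force such layers into eval mode; a sketch of the assumed companion step (commonly done in a train() override, not shown in this diff):

```python
# Keep frozen norm layers in eval mode so running stats stay fixed.
self.norm1.eval()
```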
@@ -327,15 +350,16 @@ class ResNet(nn.Module):
             # zero init for last norm layer https://arxiv.org/abs/1706.02677
             if self.zero_init_residual:
                 for m in self.modules():
-                    if isinstance(m, (Bottleneck, BasicBlock)):
-                        last_norm = getattr(m, m.norm_names[-1])
-                        constant_init(last_norm, 0)
+                    if isinstance(m, Bottleneck):
+                        constant_init(m.norm3, 0)
+                    elif isinstance(m, BasicBlock):
+                        constant_init(m.norm2, 0)
         else:
             raise TypeError('pretrained must be a str or None')

     def forward(self, x):
         x = self.conv1(x)
-        x = getattr(self, self.stem_norm)(x)
+        x = self.norm1(x)
         x = self.relu(x)
         x = self.maxpool(x)
         outs = []
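The zero-init branch implements the trick from the cited paper (https://arxiv.org/abs/1706.02677): zeroing the scale of the last norm layer in each residual block makes every block start out as approximately the identity mapping, which eases optimization early in training. Assuming constant_init(layer, 0) zeroes the affine weight with the bias defaulting to 0 (an mmcv-style convention, not shown in this diff), the plain-PyTorch equivalent is roughly:

```python
import torch.nn as nn


def zero_init_residual(model, bottleneck_cls, basicblock_cls):
    # Plain-PyTorch rendering of the loop above; assumes constant_init
    # zeroes the norm's affine weight and bias.
    for m in model.modules():
        if isinstance(m, bottleneck_cls):
            nn.init.constant_(m.norm3.weight, 0)
            nn.init.constant_(m.norm3.bias, 0)
        elif isinstance(m, basicblock_cls):
            nn.init.constant_(m.norm2.weight, 0)
            nn.init.constant_(m.norm2.bias, 0)
```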
...
@@ -53,8 +53,8 @@ class ConvModule(nn.Module):
         if self.with_norm:
             norm_channels = out_channels if self.activate_last else in_channels
-            self.norm, norm = build_norm_layer(normalize, norm_channels)
-            self.add_module(self.norm, norm)
+            self.norm_name, norm = build_norm_layer(normalize, norm_channels)
+            self.add_module(self.norm_name, norm)

         if self.with_activatation:
             assert activation in ['relu'], 'Only ReLU supported.'
@@ -64,6 +64,10 @@ class ConvModule(nn.Module):
         # Default using msra init
         self.init_weights()

+    @property
+    def norm(self):
+        return getattr(self, self.norm_name)
+
     def init_weights(self):
         nonlinearity = 'relu' if self.activation is None else self.activation
         kaiming_init(self.conv, nonlinearity=nonlinearity)
@@ -74,12 +78,12 @@ class ConvModule(nn.Module):
         if self.activate_last:
             x = self.conv(x)
             if norm and self.with_norm:
-                x = getattr(self, self.norm)(x)
+                x = self.norm(x)
             if activate and self.with_activatation:
                 x = self.activate(x)
         else:
             if norm and self.with_norm:
-                x = getattr(self, self.norm)(x)
+                x = self.norm(x)
             if activate and self.with_activatation:
                 x = self.activate(x)
             x = self.conv(x)
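For completeness, a usage sketch of the refactored ConvModule. Only the normalize dict and the norm property are grounded in this diff; the positional arguments are assumed to mirror nn.Conv2d:

```python
# Hypothetical usage; argument names other than `normalize` are assumptions.
conv = ConvModule(256, 256, 3, padding=1,
                  normalize=dict(type='GN', num_groups=32))
print(conv.norm_name)  # type-specific name chosen by build_norm_layer
print(conv.norm)       # the registered norm layer, resolved via the property
```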
...
@@ -11,13 +11,22 @@ norm_cfg = {
 def build_norm_layer(cfg, num_features, postfix=''):
-    """
-    cfg should contain:
-        type (str): identify norm layer type.
-        layer args: args needed to instantiate a norm layer.
-        frozen (bool): [optional] whether stop gradient updates
-            of norm layer, it is helpful to set frozen mode
-            in backbone's norms.
-    """
+    """Build a normalization layer.
+
+    Args:
+        cfg (dict): cfg should contain:
+            type (str): identify norm layer type.
+            layer args: args needed to instantiate a norm layer.
+            frozen (bool): [optional] whether to stop gradient updates
+                of the norm layer; it is helpful to set frozen mode
+                for norms in a backbone.
+        num_features (int): number of input channels.
+        postfix (int, str): appended to the norm abbreviation to
+            create a named layer.
+
+    Returns:
+        name (str): abbreviation + postfix
+        layer (nn.Module): created norm layer
+    """
     assert isinstance(cfg, dict) and 'type' in cfg
     cfg_ = cfg.copy()
...
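The rewritten docstring pins down the contract. Below is a minimal re-implementation consistent with it, for illustration only — the real norm_cfg mapping and the frozen handling live in the truncated part of this file:

```python
import torch.nn as nn

# Hypothetical mapping: abbreviation + layer class per norm type.
norm_cfg = {'BN': ('bn', nn.BatchNorm2d), 'GN': ('gn', nn.GroupNorm)}


def build_norm_layer_sketch(cfg, num_features, postfix=''):
    cfg_ = cfg.copy()
    layer_type = cfg_.pop('type')
    frozen = cfg_.pop('frozen', False)
    abbr, layer_cls = norm_cfg[layer_type]
    name = abbr + str(postfix)          # e.g. 'gn' + '2' -> 'gn2'
    if layer_type == 'GN':
        layer = layer_cls(num_channels=num_features, **cfg_)
    else:
        layer = layer_cls(num_features, **cfg_)
    if frozen:
        # Stop gradient updates, per the documented frozen flag.
        for param in layer.parameters():
            param.requires_grad = False
    return name, layer


name, gn = build_norm_layer_sketch(dict(type='GN', num_groups=32), 256,
                                   postfix=2)
assert name == 'gn2' and isinstance(gn, nn.GroupNorm)
```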