Commit 3fdd041c authored by ThangVu

revise group norm (3)

parent dca2d841
@@ -35,13 +35,11 @@ class BasicBlock(nn.Module):
         super(BasicBlock, self).__init__()
         self.conv1 = conv3x3(inplanes, planes, stride, dilation)
-        norm_layers = []
-        norm_layers.append(build_norm_layer(normalize, planes))
-        norm_layers.append(build_norm_layer(normalize, planes))
-        self.norm_names = (['gn1', 'gn2'] if normalize['type'] == 'GN'
-                           else ['bn1', 'bn2'])
-        for name, layer in zip(self.norm_names, norm_layers):
-            self.add_module(name, layer)
+        # build_norm_layer returns: (norm_name, norm_layer)
+        self.norm1, norm1 = build_norm_layer(normalize, planes, postfix=1)
+        self.norm2, norm2 = build_norm_layer(normalize, planes, postfix=2)
+        self.add_module(self.norm1, norm1)
+        self.add_module(self.norm2, norm2)
         self.relu = nn.ReLU(inplace=True)
         self.conv2 = conv3x3(planes, planes)
@@ -54,11 +52,11 @@ class BasicBlock(nn.Module):
         identity = x
         out = self.conv1(x)
-        out = getattr(self, self.norm_names[0])(out)
+        out = getattr(self, self.norm1)(out)
         out = self.relu(out)
         out = self.conv2(out)
-        out = getattr(self, self.norm_names[1])(out)
+        out = getattr(self, self.norm2)(out)
         if self.downsample is not None:
             identity = self.downsample(x)
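The pattern introduced here is worth spelling out: build_norm_layer now hands back a (name, module) pair, the block keeps only the name string (self.norm1), registers the module under it with add_module, and forward() recovers it with getattr. State_dict keys stay readable ('bn1.weight' or 'gn1.weight') while the call sites are norm-agnostic. A minimal self-contained sketch of the same pattern; TinyBlock is illustrative, not part of this commit:

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    """Illustrative stand-in for BasicBlock's norm handling."""

    def __init__(self, planes, norm_name, norm_layer):
        super(TinyBlock, self).__init__()
        self.conv = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        # keep only the *name* as a plain attribute; register the module
        # under that name so state_dict keys read 'gn1.*' / 'bn1.*'
        self.norm1 = norm_name
        self.add_module(self.norm1, norm_layer)

    def forward(self, x):
        # recover the registered module by its stored name
        return getattr(self, self.norm1)(self.conv(x))

blk = TinyBlock(8, 'gn1', nn.GroupNorm(num_groups=2, num_channels=8))
print(sorted(blk.state_dict()))  # ['conv.weight', 'gn1.bias', 'gn1.weight']
out = blk(torch.randn(1, 8, 4, 4))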
@@ -110,14 +108,14 @@ class Bottleneck(nn.Module):
             dilation=dilation,
             bias=False)
-        norm_layers = []
-        norm_layers.append(build_norm_layer(normalize, planes))
-        norm_layers.append(build_norm_layer(normalize, planes))
-        norm_layers.append(build_norm_layer(normalize, planes * self.expansion))
-        self.norm_names = (['gn1', 'gn2', 'gn3'] if normalize['type'] == 'GN'
-                           else ['bn1', 'bn2', 'bn3'])
-        for name, layer in zip(self.norm_names, norm_layers):
-            self.add_module(name, layer)
+        # build_norm_layer returns: (norm_name, norm_layer)
+        self.norm1, norm1 = build_norm_layer(normalize, planes, postfix=1)
+        self.norm2, norm2 = build_norm_layer(normalize, planes, postfix=2)
+        self.norm3, norm3 = build_norm_layer(normalize, planes * self.expansion,
+                                             postfix=3)
+        self.add_module(self.norm1, norm1)
+        self.add_module(self.norm2, norm2)
+        self.add_module(self.norm3, norm3)
         self.conv3 = nn.Conv2d(
             planes, planes * self.expansion, kernel_size=1, bias=False)
@@ -134,15 +132,15 @@ class Bottleneck(nn.Module):
         identity = x
         out = self.conv1(x)
-        out = getattr(self, self.norm_names[0])(out)
+        out = getattr(self, self.norm1)(out)
         out = self.relu(out)
         out = self.conv2(out)
-        out = getattr(self, self.norm_names[1])(out)
+        out = getattr(self, self.norm2)(out)
         out = self.relu(out)
         out = self.conv3(out)
-        out = getattr(self, self.norm_names[2])(out)
+        out = getattr(self, self.norm3)(out)
         if self.downsample is not None:
             identity = self.downsample(x)
@@ -179,7 +177,7 @@ def make_res_layer(block,
                 kernel_size=1,
                 stride=stride,
                 bias=False),
-            build_norm_layer(normalize, planes * block.expansion),
+            build_norm_layer(normalize, planes * block.expansion)[1],
         )

     layers = []
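Since build_norm_layer now returns a tuple, call sites that only need the module index it with [1], as in the downsample branch above; inside nn.Sequential children are addressed by position anyway, so the generated name is simply dropped. A hedged usage sketch, assuming build_norm_layer is imported from the norm helper revised later in this commit; the channel sizes are made up:

import torch.nn as nn

# build_norm_layer(...) -> (name, module); nn.Sequential wants only the module
downsample = nn.Sequential(
    nn.Conv2d(64, 256, kernel_size=1, stride=2, bias=False),
    build_norm_layer(dict(type='BN'), 256)[1],
)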
@@ -298,9 +296,9 @@ class ResNet(nn.Module):
     def _make_stem_layer(self):
         self.conv1 = nn.Conv2d(
             3, 64, kernel_size=7, stride=2, padding=3, bias=False)
-        stem_norm = build_norm_layer(self.normalize, 64)
-        self.norm_name = 'gn1' if self.normalize['type'] == 'GN' else 'bn1'
-        self.add_module(self.norm_name, stem_norm)
+        self.stem_norm, stem_norm = build_norm_layer(self.normalize, 64,
+                                                     postfix=1)
+        self.add_module(self.stem_norm, stem_norm)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@@ -337,7 +335,7 @@ class ResNet(nn.Module):
     def forward(self, x):
         x = self.conv1(x)
-        x = getattr(self, self.norm_name)(x)
+        x = getattr(self, self.stem_norm)(x)
         x = self.relu(x)
         x = self.maxpool(x)
         outs = []
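One reason the generated names matter: with a BN config the stem norm registers as 'bn1' and block norms as 'bn1'/'bn2'/'bn3', matching the attribute names in stock torchvision ResNets, so ImageNet-pretrained checkpoints keep loading by key. A quick sanity check, assuming torchvision is available:

import torchvision

keys = list(torchvision.models.resnet50().state_dict())
print(keys[:4])
# ['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean']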
...
@@ -53,7 +53,8 @@ class ConvModule(nn.Module):
         if self.with_norm:
             norm_channels = out_channels if self.activate_last else in_channels
-            self.norm = build_norm_layer(normalize, norm_channels)
+            self.norm, norm = build_norm_layer(normalize, norm_channels)
+            self.add_module(self.norm, norm)
         if self.with_activatation:
             assert activation in ['relu'], 'Only ReLU supported.'
@@ -73,12 +74,12 @@ class ConvModule(nn.Module):
         if self.activate_last:
             x = self.conv(x)
             if norm and self.with_norm:
-                x = self.norm(x)
+                x = getattr(self, self.norm)(x)
             if activate and self.with_activatation:
                 x = self.activate(x)
         else:
             if norm and self.with_norm:
-                x = self.norm(x)
+                x = getattr(self, self.norm)(x)
             if activate and self.with_activatation:
                 x = self.activate(x)
             x = self.conv(x)
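ConvModule calls build_norm_layer without a postfix, so with the default postfix='' the registered attribute is plain 'bn' or 'gn'. A hedged check against the helper in the next file:

name, layer = build_norm_layer(dict(type='GN', num_groups=32), 256)
print(name)   # 'gn'  (abbreviation + empty postfix)
print(layer)  # GroupNorm(32, 256, eps=1e-05, affine=True)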
...
 import torch.nn as nn

-norm_cfg = {'BN': nn.BatchNorm2d, 'SyncBN': None, 'GN': nn.GroupNorm}
+norm_cfg = {
+    # format: layer_type: (abbreviation, module)
+    'BN': ('bn', nn.BatchNorm2d),
+    'SyncBN': ('bn', None),
+    'GN': ('gn', nn.GroupNorm),
+    # and potentially 'SN'
+}
-def build_norm_layer(cfg, num_features):
+def build_norm_layer(cfg, num_features, postfix=''):
     """
     cfg should contain:
         type (str): identify norm layer type.
@@ -19,22 +25,24 @@ def build_norm_layer(cfg, num_features):
     layer_type = cfg_.pop('type')
     if layer_type not in norm_cfg:
         raise KeyError('Unrecognized norm type {}'.format(layer_type))
-    elif norm_cfg[layer_type] is None:
-        raise NotImplementedError
+    else:
+        abbr, norm_layer = norm_cfg[layer_type]
+        if norm_layer is None:
+            raise NotImplementedError
+
+    assert isinstance(postfix, (int, str))
+    name = abbr + str(postfix)

     frozen = cfg_.pop('frozen', False)
-    # args name matching
-    if layer_type in ['GN']:
-        assert 'num_groups' in cfg
-        cfg_.setdefault('num_channels', num_features)
-    elif layer_type in ['BN']:
-        cfg_.setdefault('num_features', num_features)
-    else:
-        raise NotImplementedError
     cfg_.setdefault('eps', 1e-5)
-    norm = norm_cfg[layer_type](**cfg_)
+    if layer_type != 'GN':
+        layer = norm_layer(num_features, **cfg_)
+    else:
+        assert 'num_groups' in cfg_
+        layer = norm_layer(num_channels=num_features, **cfg_)

     if frozen:
-        for param in norm.parameters():
+        for param in layer.parameters():
             param.requires_grad = False

-    return norm
+    return name, layer
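Put together, the revised helper can be exercised like this (a sketch, assuming this norm.py is on the import path):

import torch.nn as nn

name, bn = build_norm_layer(dict(type='BN'), 64, postfix=1)
assert name == 'bn1' and isinstance(bn, nn.BatchNorm2d)

# GN requires num_groups; frozen=True disables gradients on the affine params
name, gn = build_norm_layer(dict(type='GN', num_groups=32, frozen=True),
                            64, postfix=2)
assert name == 'gn2'
assert all(not p.requires_grad for p in gn.parameters())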