Commit 3fdd041c authored by ThangVu

revise group norm (3)

parent dca2d841
@@ -35,13 +35,11 @@ class BasicBlock(nn.Module):
         super(BasicBlock, self).__init__()
         self.conv1 = conv3x3(inplanes, planes, stride, dilation)
-        norm_layers = []
-        norm_layers.append(build_norm_layer(normalize, planes))
-        norm_layers.append(build_norm_layer(normalize, planes))
-        self.norm_names = (['gn1', 'gn2'] if normalize['type'] == 'GN'
-                           else ['bn1', 'bn2'])
-        for name, layer in zip(self.norm_names, norm_layers):
-            self.add_module(name, layer)
+        # build_norm_layer return: (norm_name, norm_layer)
+        self.norm1, norm1 = build_norm_layer(normalize, planes, postfix=1)
+        self.norm2, norm2 = build_norm_layer(normalize, planes, postfix=2)
+        self.add_module(self.norm1, norm1)
+        self.add_module(self.norm2, norm2)
         self.relu = nn.ReLU(inplace=True)
         self.conv2 = conv3x3(planes, planes)
@@ -54,11 +52,11 @@ class BasicBlock(nn.Module):
         identity = x
         out = self.conv1(x)
-        out = getattr(self, self.norm_names[0])(out)
+        out = getattr(self, self.norm1)(out)
         out = self.relu(out)
         out = self.conv2(out)
-        out = getattr(self, self.norm_names[1])(out)
+        out = getattr(self, self.norm2)(out)
         if self.downsample is not None:
             identity = self.downsample(x)
@@ -110,14 +108,14 @@ class Bottleneck(nn.Module):
             dilation=dilation,
             bias=False)
-        norm_layers = []
-        norm_layers.append(build_norm_layer(normalize, planes))
-        norm_layers.append(build_norm_layer(normalize, planes))
-        norm_layers.append(build_norm_layer(normalize, planes*self.expansion))
-        self.norm_names = (['gn1', 'gn2', 'gn3'] if normalize['type'] == 'GN'
-                           else ['bn1', 'bn2', 'bn3'])
-        for name, layer in zip(self.norm_names, norm_layers):
-            self.add_module(name, layer)
+        # build_norm_layer return: (norm_name, norm_layer)
+        self.norm1, norm1 = build_norm_layer(normalize, planes, postfix=1)
+        self.norm2, norm2 = build_norm_layer(normalize, planes, postfix=2)
+        self.norm3, norm3 = build_norm_layer(normalize, planes*self.expansion,
+                                             postfix=3)
+        self.add_module(self.norm1, norm1)
+        self.add_module(self.norm2, norm2)
+        self.add_module(self.norm3, norm3)
         self.conv3 = nn.Conv2d(
             planes, planes * self.expansion, kernel_size=1, bias=False)
@@ -134,15 +132,15 @@ class Bottleneck(nn.Module):
         identity = x
         out = self.conv1(x)
-        out = getattr(self, self.norm_names[0])(out)
+        out = getattr(self, self.norm1)(out)
         out = self.relu(out)
         out = self.conv2(out)
-        out = getattr(self, self.norm_names[1])(out)
+        out = getattr(self, self.norm2)(out)
         out = self.relu(out)
         out = self.conv3(out)
-        out = getattr(self, self.norm_names[2])(out)
+        out = getattr(self, self.norm3)(out)
         if self.downsample is not None:
             identity = self.downsample(x)
@@ -179,7 +177,7 @@ def make_res_layer(block,
                 kernel_size=1,
                 stride=stride,
                 bias=False),
-            build_norm_layer(normalize, planes * block.expansion),
+            build_norm_layer(normalize, planes * block.expansion)[1],
         )
     layers = []
@@ -298,9 +296,9 @@ class ResNet(nn.Module):
     def _make_stem_layer(self):
         self.conv1 = nn.Conv2d(
             3, 64, kernel_size=7, stride=2, padding=3, bias=False)
-        stem_norm = build_norm_layer(self.normalize, 64)
-        self.norm_name = 'gn1' if self.normalize['type'] == 'GN' else 'bn1'
-        self.add_module(self.norm_name, stem_norm)
+        self.stem_norm, stem_norm = build_norm_layer(self.normalize,
+                                                     64, postfix=1)
+        self.add_module(self.stem_norm, stem_norm)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@@ -337,7 +335,7 @@ class ResNet(nn.Module):
     def forward(self, x):
         x = self.conv1(x)
-        x = getattr(self, self.norm_name)(x)
+        x = getattr(self, self.stem_norm)(x)
         x = self.relu(x)
         x = self.maxpool(x)
         outs = []
@@ -53,7 +53,8 @@ class ConvModule(nn.Module):
         if self.with_norm:
             norm_channels = out_channels if self.activate_last else in_channels
-            self.norm = build_norm_layer(normalize, norm_channels)
+            self.norm, norm = build_norm_layer(normalize, norm_channels)
+            self.add_module(self.norm, norm)
         if self.with_activatation:
             assert activation in ['relu'], 'Only ReLU supported.'
@@ -73,12 +74,12 @@ class ConvModule(nn.Module):
         if self.activate_last:
             x = self.conv(x)
             if norm and self.with_norm:
-                x = self.norm(x)
+                x = getattr(self, self.norm)(x)
             if activate and self.with_activatation:
                 x = self.activate(x)
         else:
             if norm and self.with_norm:
-                x = self.norm(x)
+                x = getattr(self, self.norm)(x)
             if activate and self.with_activatation:
                 x = self.activate(x)
             x = self.conv(x)
 import torch.nn as nn
-norm_cfg = {'BN': nn.BatchNorm2d, 'SyncBN': None, 'GN': nn.GroupNorm}
+norm_cfg = {
+    # format: layer_type: (abbreation, module)
+    'BN': ('bn', nn.BatchNorm2d),
+    'SyncBN': ('bn', None),
+    'GN': ('gn', nn.GroupNorm),
+    # and potentially 'SN'
+}
-def build_norm_layer(cfg, num_features):
+def build_norm_layer(cfg, num_features, postfix=''):
     """
     cfg should contain:
         type (str): identify norm layer type.
@@ -19,22 +25,24 @@ def build_norm_layer(cfg, num_features):
     layer_type = cfg_.pop('type')
     if layer_type not in norm_cfg:
         raise KeyError('Unrecognized norm type {}'.format(layer_type))
-    elif norm_cfg[layer_type] is None:
-        raise NotImplementedError
+    else:
+        abbr, norm_layer = norm_cfg[layer_type]
+        if norm_layer is None:
+            raise NotImplementedError
+    assert isinstance(postfix, (int, str))
+    name = abbr + str(postfix)
     frozen = cfg_.pop('frozen', False)
-    # args name matching
-    if layer_type in ['GN']:
-        assert 'num_groups' in cfg
-        cfg_.setdefault('num_channels', num_features)
-    elif layer_type in ['BN']:
-        cfg_.setdefault('num_features', num_features)
-    else:
-        raise NotImplementedError
     cfg_.setdefault('eps', 1e-5)
+    if layer_type != 'GN':
+        layer = norm_layer(num_features, **cfg_)
+    else:
+        assert 'num_groups' in cfg_
+        layer = norm_layer(num_channels=num_features, **cfg_)
-    norm = norm_cfg[layer_type](**cfg_)
     if frozen:
-        for param in norm.parameters():
+        for param in layer.parameters():
             param.requires_grad = False
-    return norm
+    return name, layer
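
For reference, a minimal usage sketch of the revised helper, assuming the build_norm_layer defined above is in scope; the TinyBlock class and the config dicts are illustrative and not part of the commit:

import torch
import torch.nn as nn

gn_cfg = dict(type='GN', num_groups=32)   # hypothetical GN config
bn_cfg = dict(type='BN')                  # hypothetical BN config

# build_norm_layer now returns a (name, layer) pair; the postfix becomes part
# of the attribute name ('gn1'/'bn1', 'gn2'/'bn2', ...), so callers register
# the layer with add_module and fetch it later via getattr.
name, layer = build_norm_layer(gn_cfg, 64, postfix=1)  # name == 'gn1', layer is nn.GroupNorm(32, 64)

class TinyBlock(nn.Module):
    def __init__(self, planes=64, normalize=bn_cfg):
        super(TinyBlock, self).__init__()
        self.conv1 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        # self.norm1 stores the string name; the module is registered under it
        self.norm1, norm1 = build_norm_layer(normalize, planes, postfix=1)
        self.add_module(self.norm1, norm1)

    def forward(self, x):
        return getattr(self, self.norm1)(self.conv1(x))

block = TinyBlock()                       # norm layer registered as block.bn1
out = block(torch.randn(2, 64, 8, 8))

The string kept in self.norm1 ('bn1' or 'gn1') replaces the old norm_names bookkeeping, which is why the backbone forward passes above switch from getattr(self, self.norm_names[i]) to getattr(self, self.normN).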