Commit 55a4feb5 authored by thangvu's avatar thangvu
Browse files

revise norm order in backbone resblocks and minor fix

parent bc5ec9bf
...@@ -33,15 +33,16 @@ class BasicBlock(nn.Module): ...@@ -33,15 +33,16 @@ class BasicBlock(nn.Module):
with_cp=False, with_cp=False,
normalize=dict(type='BN')): normalize=dict(type='BN')):
super(BasicBlock, self).__init__() super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride, dilation)
self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1) self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2) self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)
self.conv1 = conv3x3(inplanes, planes, stride, dilation)
self.add_module(self.norm1_name, norm1) self.add_module(self.norm1_name, norm1)
self.conv2 = conv3x3(planes, planes)
self.add_module(self.norm2_name, norm2) self.add_module(self.norm2_name, norm2)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.downsample = downsample self.downsample = downsample
self.stride = stride self.stride = stride
self.dilation = dilation self.dilation = dilation
...@@ -101,12 +102,20 @@ class Bottleneck(nn.Module): ...@@ -101,12 +102,20 @@ class Bottleneck(nn.Module):
else: else:
self.conv1_stride = stride self.conv1_stride = stride
self.conv2_stride = 1 self.conv2_stride = 1
self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)
self.norm3_name, norm3 = build_norm_layer(normalize,
planes * self.expansion,
postfix=3)
self.conv1 = nn.Conv2d( self.conv1 = nn.Conv2d(
inplanes, inplanes,
planes, planes,
kernel_size=1, kernel_size=1,
stride=self.conv1_stride, stride=self.conv1_stride,
bias=False) bias=False)
self.add_module(self.norm1_name, norm1)
self.conv2 = nn.Conv2d( self.conv2 = nn.Conv2d(
planes, planes,
planes, planes,
...@@ -115,18 +124,11 @@ class Bottleneck(nn.Module): ...@@ -115,18 +124,11 @@ class Bottleneck(nn.Module):
padding=dilation, padding=dilation,
dilation=dilation, dilation=dilation,
bias=False) bias=False)
self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)
self.norm3_name, norm3 = build_norm_layer(normalize,
planes * self.expansion,
postfix=3)
self.add_module(self.norm1_name, norm1)
self.add_module(self.norm2_name, norm2) self.add_module(self.norm2_name, norm2)
self.add_module(self.norm3_name, norm3)
self.conv3 = nn.Conv2d( self.conv3 = nn.Conv2d(
planes, planes * self.expansion, kernel_size=1, bias=False) planes, planes * self.expansion, kernel_size=1, bias=False)
self.add_module(self.norm3_name, norm3)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.downsample = downsample self.downsample = downsample
self.stride = stride self.stride = stride
......
...@@ -30,9 +30,6 @@ class Bottleneck(_Bottleneck): ...@@ -30,9 +30,6 @@ class Bottleneck(_Bottleneck):
self.norm3_name, norm3 = build_norm_layer(self.normalize, self.norm3_name, norm3 = build_norm_layer(self.normalize,
self.planes * self.expansion, self.planes * self.expansion,
postfix=3) postfix=3)
self.add_module(self.norm1_name, norm1)
self.add_module(self.norm2_name, norm2)
self.add_module(self.norm3_name, norm3)
self.conv1 = nn.Conv2d( self.conv1 = nn.Conv2d(
self.inplanes, self.inplanes,
...@@ -40,6 +37,7 @@ class Bottleneck(_Bottleneck): ...@@ -40,6 +37,7 @@ class Bottleneck(_Bottleneck):
kernel_size=1, kernel_size=1,
stride=self.conv1_stride, stride=self.conv1_stride,
bias=False) bias=False)
self.add_module(self.norm1_name, norm1)
self.conv2 = nn.Conv2d( self.conv2 = nn.Conv2d(
width, width,
width, width,
...@@ -49,8 +47,10 @@ class Bottleneck(_Bottleneck): ...@@ -49,8 +47,10 @@ class Bottleneck(_Bottleneck):
dilation=self.dilation, dilation=self.dilation,
groups=groups, groups=groups,
bias=False) bias=False)
self.add_module(self.norm2_name, norm2)
self.conv3 = nn.Conv2d( self.conv3 = nn.Conv2d(
width, self.planes * self.expansion, kernel_size=1, bias=False) width, self.planes * self.expansion, kernel_size=1, bias=False)
self.add_module(self.norm3_name, norm3)
def make_res_layer(block, def make_res_layer(block,
......
...@@ -2,7 +2,7 @@ import torch.nn as nn ...@@ -2,7 +2,7 @@ import torch.nn as nn
norm_cfg = { norm_cfg = {
# format: layer_type: (abbreation, module) # format: layer_type: (abbreviation, module)
'BN': ('bn', nn.BatchNorm2d), 'BN': ('bn', nn.BatchNorm2d),
'SyncBN': ('bn', None), 'SyncBN': ('bn', None),
'GN': ('gn', nn.GroupNorm), 'GN': ('gn', nn.GroupNorm),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment