Commit bc5ec9bf authored by ThangVu

add group norm for resnext

parent 81558853
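This commit threads a `normalize` config dict through `Bottleneck`, `make_res_layer`, and `ResNeXt`, replacing the hard-coded `nn.BatchNorm2d` layers with `build_norm_layer` so that group norm can be selected from the config. A minimal usage sketch, assuming `ResNeXt` is importable from the backbones package and that the GN config uses a `num_groups` key (neither the import path nor the GN keys are shown in this diff):

# Hypothetical usage: build a ResNeXt backbone with group norm instead of
# batch norm via the new `normalize` argument added in this commit.
# The import path and the exact GN keys (e.g. num_groups) are assumptions.
from mmdet.models.backbones import ResNeXt

model = ResNeXt(
    depth=50,
    groups=32,
    base_width=4,
    normalize=dict(type='GN', num_groups=32))  # dict(type='BN') is the default in make_res_layer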
resnet.py
@@ -94,6 +94,7 @@ class Bottleneck(nn.Module):
         assert style in ['pytorch', 'caffe']
         self.inplanes = inplanes
         self.planes = planes
+        self.normalize = normalize
         if style == 'pytorch':
             self.conv1_stride = 1
             self.conv2_stride = stride
...
resnext.py
@@ -4,6 +4,7 @@ import torch.nn as nn
 from .resnet import ResNet
 from .resnet import Bottleneck as _Bottleneck
+from ..utils import build_norm_layer


 class Bottleneck(_Bottleneck):
@@ -20,13 +21,25 @@ class Bottleneck(_Bottleneck):
         else:
             width = math.floor(self.planes * (base_width / 64)) * groups

+        self.norm1_name, norm1 = build_norm_layer(self.normalize,
+                                                  width,
+                                                  postfix=1)
+        self.norm2_name, norm2 = build_norm_layer(self.normalize,
+                                                  width,
+                                                  postfix=2)
+        self.norm3_name, norm3 = build_norm_layer(self.normalize,
+                                                  self.planes * self.expansion,
+                                                  postfix=3)
+        self.add_module(self.norm1_name, norm1)
+        self.add_module(self.norm2_name, norm2)
+        self.add_module(self.norm3_name, norm3)
+
         self.conv1 = nn.Conv2d(
             self.inplanes,
             width,
             kernel_size=1,
             stride=self.conv1_stride,
             bias=False)
-        self.bn1 = nn.BatchNorm2d(width)
         self.conv2 = nn.Conv2d(
             width,
             width,
@@ -36,10 +49,8 @@ class Bottleneck(_Bottleneck):
             dilation=self.dilation,
             groups=groups,
             bias=False)
-        self.bn2 = nn.BatchNorm2d(width)
         self.conv3 = nn.Conv2d(
             width, self.planes * self.expansion, kernel_size=1, bias=False)
-        self.bn3 = nn.BatchNorm2d(self.planes * self.expansion)


 def make_res_layer(block,
@@ -51,7 +62,8 @@ def make_res_layer(block,
                    groups=1,
                    base_width=4,
                    style='pytorch',
-                   with_cp=False):
+                   with_cp=False,
+                   normalize=dict(type='BN')):
     downsample = None
     if stride != 1 or inplanes != planes * block.expansion:
         downsample = nn.Sequential(
@@ -61,7 +73,7 @@ def make_res_layer(block,
                 kernel_size=1,
                 stride=stride,
                 bias=False),
-            nn.BatchNorm2d(planes * block.expansion),
+            build_norm_layer(normalize, planes * block.expansion)[1],
         )

     layers = []
@@ -75,7 +87,8 @@ def make_res_layer(block,
             groups=groups,
             base_width=base_width,
             style=style,
-            with_cp=with_cp))
+            with_cp=with_cp,
+            normalize=normalize))
     inplanes = planes * block.expansion
     for i in range(1, blocks):
         layers.append(
@@ -87,7 +100,8 @@ def make_res_layer(block,
                 groups=groups,
                 base_width=base_width,
                 style=style,
-                with_cp=with_cp))
+                with_cp=with_cp,
+                normalize=normalize))

     return nn.Sequential(*layers)
@@ -108,9 +122,9 @@ class ResNeXt(ResNet):
             the first 1x1 conv layer.
         frozen_stages (int): Stages to be frozen (all param fixed). -1 means
             not freezing any parameters.
-        bn_eval (bool): Whether to set BN layers to eval mode, namely, freeze
-            running stats (mean and var).
-        bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+        normalize (dict): dictionary to construct the norm layer. Additionally,
+            eval mode and gradient freezing are controlled by the
+            eval (bool) and frozen (bool) keys, respectively.
         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
             memory while slowing down the training speed.
     """
@@ -142,8 +156,11 @@ class ResNeXt(ResNet):
                 groups=self.groups,
                 base_width=self.base_width,
                 style=self.style,
-                with_cp=self.with_cp)
+                with_cp=self.with_cp,
+                normalize=self.normalize)
             self.inplanes = planes * self.block.expansion
             layer_name = 'layer{}'.format(i + 1)
             self.add_module(layer_name, res_layer)
             self.res_layers.append(layer_name)
+
+        self._freeze_stages()
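From its usage above, `build_norm_layer(cfg, num_features, postfix)` is expected to return a `(name, layer)` pair: the name (derived from the postfix) is registered with `add_module`, and the layer is the constructed norm module. A rough sketch of that contract, assuming only 'BN' and 'GN' types and a `num_groups` key for GN; the real helper imported from `..utils` may handle more options (e.g. the eval/frozen flags mentioned in the docstring):

import torch.nn as nn

def build_norm_layer(cfg, num_features, postfix=''):
    # Sketch of the interface implied by the diff, not the actual ..utils code.
    layer_type = cfg['type']
    if layer_type == 'BN':
        # e.g. postfix=1 -> module registered as 'bn1'
        return 'bn{}'.format(postfix), nn.BatchNorm2d(num_features)
    elif layer_type == 'GN':
        # e.g. postfix=1 -> module registered as 'gn1'
        return 'gn{}'.format(postfix), nn.GroupNorm(cfg['num_groups'], num_features)
    raise KeyError('Unsupported norm type: {}'.format(layer_type))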