Commit bc5ec9bf authored by ThangVu
Browse files

add group norm for resnext

parent 81558853
......@@ -94,6 +94,7 @@ class Bottleneck(nn.Module):
assert style in ['pytorch', 'caffe']
self.inplanes = inplanes
self.planes = planes
self.normalize = normalize
if style == 'pytorch':
self.conv1_stride = 1
self.conv2_stride = stride
......
......@@ -4,6 +4,7 @@ import torch.nn as nn
from .resnet import ResNet
from .resnet import Bottleneck as _Bottleneck
from ..utils import build_norm_layer
class Bottleneck(_Bottleneck):
......@@ -20,13 +21,25 @@ class Bottleneck(_Bottleneck):
else:
width = math.floor(self.planes * (base_width / 64)) * groups
self.norm1_name, norm1 = build_norm_layer(self.normalize,
width,
postfix=1)
self.norm2_name, norm2 = build_norm_layer(self.normalize,
width,
postfix=2)
self.norm3_name, norm3 = build_norm_layer(self.normalize,
self.planes * self.expansion,
postfix=3)
self.add_module(self.norm1_name, norm1)
self.add_module(self.norm2_name, norm2)
self.add_module(self.norm3_name, norm3)
self.conv1 = nn.Conv2d(
self.inplanes,
width,
kernel_size=1,
stride=self.conv1_stride,
bias=False)
self.bn1 = nn.BatchNorm2d(width)
self.conv2 = nn.Conv2d(
width,
width,
......@@ -36,10 +49,8 @@ class Bottleneck(_Bottleneck):
dilation=self.dilation,
groups=groups,
bias=False)
self.bn2 = nn.BatchNorm2d(width)
self.conv3 = nn.Conv2d(
width, self.planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(self.planes * self.expansion)
def make_res_layer(block,
......@@ -51,7 +62,8 @@ def make_res_layer(block,
groups=1,
base_width=4,
style='pytorch',
with_cp=False):
with_cp=False,
normalize=dict(type='BN')):
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
......@@ -61,7 +73,7 @@ def make_res_layer(block,
kernel_size=1,
stride=stride,
bias=False),
nn.BatchNorm2d(planes * block.expansion),
build_norm_layer(normalize, planes * block.expansion)[1],
)
layers = []
......@@ -75,7 +87,8 @@ def make_res_layer(block,
groups=groups,
base_width=base_width,
style=style,
with_cp=with_cp))
with_cp=with_cp,
normalize=normalize))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
......@@ -87,7 +100,8 @@ def make_res_layer(block,
groups=groups,
base_width=base_width,
style=style,
with_cp=with_cp))
with_cp=with_cp,
normalize=normalize))
return nn.Sequential(*layers)
......@@ -108,9 +122,9 @@ class ResNeXt(ResNet):
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
bn_eval (bool): Whether to set BN layers to eval mode, namely, freeze
running stats (mean and var).
bn_frozen (bool): Whether to freeze weight and bias of BN layers.
normalize (dict): dictionary to construct norm layer. Additionally,
eval mode and gradient freezing are controlled by
eval (bool) and frozen (bool) respectively.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
"""
......@@ -142,8 +156,11 @@ class ResNeXt(ResNet):
groups=self.groups,
base_width=self.base_width,
style=self.style,
with_cp=self.with_cp)
with_cp=self.with_cp,
normalize=self.normalize)
self.inplanes = planes * self.block.expansion
layer_name = 'layer{}'.format(i + 1)
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)
self._freeze_stages()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment