Commit ba3543b3 authored by pangjm

update resnext bottleneck

parent c5553785
@@ -77,19 +77,21 @@ class Bottleneck(nn.Module):
         """
         super(Bottleneck, self).__init__()
         assert style in ['pytorch', 'caffe']
+        self.inplanes = inplanes
+        self.planes = planes
         if style == 'pytorch':
-            conv1_stride = 1
-            conv2_stride = stride
+            self.conv1_stride = 1
+            self.conv2_stride = stride
         else:
-            conv1_stride = stride
-            conv2_stride = 1
+            self.conv1_stride = stride
+            self.conv2_stride = 1
         self.conv1 = nn.Conv2d(
-            inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False)
+            inplanes, planes, kernel_size=1, stride=self.conv1_stride, bias=False)
         self.conv2 = nn.Conv2d(
             planes,
             planes,
             kernel_size=3,
-            stride=conv2_stride,
+            stride=self.conv2_stride,
             padding=dilation,
             dilation=dilation,
             bias=False)
@@ -4,93 +4,43 @@ import torch.nn as nn
 import torch.utils.checkpoint as cp
 from .resnet import ResNet
+from .resnet import Bottleneck as _Bottleneck
-class Bottleneck(nn.Module):
-    expansion = 4
+class Bottleneck(_Bottleneck):
-    def __init__(self,
-                 inplanes,
-                 planes,
-                 stride=1,
-                 dilation=1,
-                 downsample=None,
-                 groups=1,
-                 base_width=4,
-                 style='pytorch',
-                 with_cp=False):
+    def __init__(self, *args, groups=1, base_width=4, **kwargs):
         """Bottleneck block for ResNeXt.
         If style is "pytorch", the stride-two layer is the 3x3 conv layer,
         if it is "caffe", the stride-two layer is the first 1x1 conv layer.
         """
-        super(Bottleneck, self).__init__()
-        assert style in ['pytorch', 'caffe']
+        super(Bottleneck, self).__init__(*args, **kwargs)
         if groups == 1:
-            width = planes
+            width = self.planes
         else:
-            width = math.floor(planes * (base_width / 64)) * groups
-        if style == 'pytorch':
-            conv1_stride = 1
-            conv2_stride = stride
-        else:
-            conv1_stride = stride
-            conv2_stride = 1
+            width = math.floor(self.planes * (base_width / 64)) * groups
         self.conv1 = nn.Conv2d(
-            inplanes, width, kernel_size=1, stride=conv1_stride, bias=False)
+            self.inplanes,
+            width,
+            kernel_size=1,
+            stride=self.conv1_stride,
+            bias=False)
         self.bn1 = nn.BatchNorm2d(width)
         self.conv2 = nn.Conv2d(
             width,
             width,
             kernel_size=3,
-            stride=conv2_stride,
-            padding=dilation,
-            dilation=dilation,
+            stride=self.conv2_stride,
+            padding=self.dilation,
+            dilation=self.dilation,
             groups=groups,
             bias=False)
         self.bn2 = nn.BatchNorm2d(width)
         self.conv3 = nn.Conv2d(
-            width, planes * self.expansion, kernel_size=1, bias=False)
-        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
-        self.relu = nn.ReLU(inplace=True)
-        self.downsample = downsample
-        self.stride = stride
-        self.dilation = dilation
-        self.with_cp = with_cp
-    def forward(self, x):
-        def _inner_forward(x):
-            identity = x
-            out = self.conv1(x)
-            out = self.bn1(out)
-            out = self.relu(out)
-            out = self.conv2(out)
-            out = self.bn2(out)
-            out = self.relu(out)
-            out = self.conv3(out)
-            out = self.bn3(out)
-            if self.downsample is not None:
-                identity = self.downsample(x)
-            out += identity
-            return out
-        if self.with_cp and x.requires_grad:
-            out = cp.checkpoint(_inner_forward, x)
-        else:
-            out = _inner_forward(x)
-        out = self.relu(out)
-        return out
+            width, self.planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(self.planes * self.expansion)
 def make_res_layer(block,
@@ -120,9 +70,9 @@ def make_res_layer(block,
         block(
             inplanes,
             planes,
-            stride,
-            dilation,
-            downsample,
+            stride=stride,
+            dilation=dilation,
+            downsample=downsample,
             groups=groups,
             base_width=base_width,
             style=style,
@@ -133,8 +83,8 @@ def make_res_layer(block,
             block(
                 inplanes,
                 planes,
-                1,
-                dilation,
+                stride=1,
+                dilation=dilation,
                 groups=groups,
                 base_width=base_width,
                 style=style,
@@ -172,12 +122,8 @@ class ResNeXt(ResNet):
         152: (Bottleneck, (3, 8, 36, 3))
     }
-    def __init__(self,
-                 groups=1,
-                 base_width=4,
-                 *args,
-                 **kwargs):
-        super(ResNeXt, self).__init__(*args, **kwargs)
+    def __init__(self, groups=1, base_width=4, **kwargs):
+        super(ResNeXt, self).__init__(**kwargs)
         self.groups = groups
         self.base_width = base_width
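
The grouped width used in the new Bottleneck.__init__ follows the ResNeXt convention: width = floor(planes * base_width / 64) * groups when groups > 1, otherwise just planes. A quick standalone check of that arithmetic (illustration only, not part of this commit; the helper name resnext_width is made up):

import math

def resnext_width(planes, groups=1, base_width=4):
    # Width of the grouped 3x3 conv, mirroring the branch in __init__ above.
    if groups == 1:
        return planes
    return math.floor(planes * (base_width / 64)) * groups

# For a 32x4d setting (groups=32, base_width=4), the four stages give:
for planes in (64, 128, 256, 512):
    print(planes, '->', resnext_width(planes, groups=32, base_width=4))
# 64 -> 128, 128 -> 256, 256 -> 512, 512 -> 1024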
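
For an end-to-end picture of what the refactored block computes, below is a self-contained PyTorch sketch that mirrors the diff: stride placement chosen by style ("pytorch" puts the stride on the 3x3 conv, "caffe" on the first 1x1, as the docstring says) and a grouped 3x3 conv at the width computed above. It is an independent reimplementation for illustration; the class name and the concrete shapes in the check are assumptions, not code from this repository.

import math
import torch
import torch.nn as nn


class ResNeXtBottleneckSketch(nn.Module):
    # Illustrative stand-in for the Bottleneck wired up in the diff above.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1,
                 downsample=None, groups=1, base_width=4, style='pytorch'):
        super().__init__()
        assert style in ['pytorch', 'caffe']
        # 'pytorch' style strides the 3x3 conv, 'caffe' strides the first 1x1.
        conv1_stride, conv2_stride = (1, stride) if style == 'pytorch' else (stride, 1)
        width = planes if groups == 1 else math.floor(planes * (base_width / 64)) * groups

        self.conv1 = nn.Conv2d(inplanes, width, 1, stride=conv1_stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, 3, stride=conv2_stride,
                               padding=dilation, dilation=dilation,
                               groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu(out + identity)


# Shape check for a stride-2, 32x4d block with a projection shortcut,
# using the same keyword arguments that make_res_layer now passes.
downsample = nn.Sequential(
    nn.Conv2d(256, 128 * 4, 1, stride=2, bias=False),
    nn.BatchNorm2d(128 * 4))
block = ResNeXtBottleneckSketch(256, 128, stride=2, downsample=downsample,
                                groups=32, base_width=4, style='pytorch')
print(block(torch.randn(1, 256, 56, 56)).shape)  # torch.Size([1, 512, 28, 28])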