"vscode:/vscode.git/clone" did not exist on "7a685349651ebf680a85df2fb701eee36b5effde"
Commit ba3543b3 authored by pangjm

update resnext bottleneck

parent c5553785
@@ -77,19 +77,21 @@ class Bottleneck(nn.Module):
         """
         super(Bottleneck, self).__init__()
         assert style in ['pytorch', 'caffe']
+        self.inplanes = inplanes
+        self.planes = planes
         if style == 'pytorch':
-            conv1_stride = 1
-            conv2_stride = stride
+            self.conv1_stride = 1
+            self.conv2_stride = stride
         else:
-            conv1_stride = stride
-            conv2_stride = 1
+            self.conv1_stride = stride
+            self.conv2_stride = 1
         self.conv1 = nn.Conv2d(
-            inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False)
+            inplanes, planes, kernel_size=1, stride=self.conv1_stride, bias=False)
         self.conv2 = nn.Conv2d(
             planes,
             planes,
             kernel_size=3,
-            stride=conv2_stride,
+            stride=self.conv2_stride,
             padding=dilation,
             dilation=dilation,
             bias=False)
...
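The point of this hunk is that the style-dependent strides become attributes (self.conv1_stride, self.conv2_stride) instead of locals, so a subclass can reuse them. The snippet below is a minimal, self-contained sketch of that pattern only; BaseBlock and ChildBlock are toy stand-ins, not the project's classes, and the real Bottleneck has more parameters.

import torch.nn as nn

class BaseBlock(nn.Module):
    # Toy stand-in for the base Bottleneck; records the strides as attributes.
    def __init__(self, inplanes, planes, stride=1, style='pytorch'):
        super(BaseBlock, self).__init__()
        assert style in ['pytorch', 'caffe']
        self.inplanes = inplanes
        self.planes = planes
        # 'pytorch' style puts the stride on the 3x3 conv, 'caffe' on the first 1x1.
        if style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1

class ChildBlock(BaseBlock):
    # A subclass can now read the stored strides instead of re-deriving them.
    def __init__(self, *args, **kwargs):
        super(ChildBlock, self).__init__(*args, **kwargs)
        self.conv1 = nn.Conv2d(
            self.inplanes, self.planes, kernel_size=1,
            stride=self.conv1_stride, bias=False)

blk = ChildBlock(256, 64, stride=2, style='caffe')
print(blk.conv1_stride, blk.conv2_stride)  # 2 1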
@@ -4,93 +4,43 @@ import torch.nn as nn
 import torch.utils.checkpoint as cp
 
 from .resnet import ResNet
+from .resnet import Bottleneck as _Bottleneck
 
 
-class Bottleneck(nn.Module):
-    expansion = 4
+class Bottleneck(_Bottleneck):
 
-    def __init__(self,
-                 inplanes,
-                 planes,
-                 stride=1,
-                 dilation=1,
-                 downsample=None,
-                 groups=1,
-                 base_width=4,
-                 style='pytorch',
-                 with_cp=False):
+    def __init__(self, *args, groups=1, base_width=4, **kwargs):
         """Bottleneck block for ResNeXt.
         If style is "pytorch", the stride-two layer is the 3x3 conv layer,
         if it is "caffe", the stride-two layer is the first 1x1 conv layer.
         """
-        super(Bottleneck, self).__init__()
-        assert style in ['pytorch', 'caffe']
+        super(Bottleneck, self).__init__(*args, **kwargs)
 
         if groups == 1:
-            width = planes
+            width = self.planes
         else:
-            width = math.floor(planes * (base_width / 64)) * groups
-
-        if style == 'pytorch':
-            conv1_stride = 1
-            conv2_stride = stride
-        else:
-            conv1_stride = stride
-            conv2_stride = 1
+            width = math.floor(self.planes * (base_width / 64)) * groups
 
         self.conv1 = nn.Conv2d(
-            inplanes, width, kernel_size=1, stride=conv1_stride, bias=False)
+            self.inplanes,
+            width,
+            kernel_size=1,
+            stride=self.conv1_stride,
+            bias=False)
         self.bn1 = nn.BatchNorm2d(width)
         self.conv2 = nn.Conv2d(
             width,
             width,
             kernel_size=3,
-            stride=conv2_stride,
-            padding=dilation,
-            dilation=dilation,
+            stride=self.conv2_stride,
+            padding=self.dilation,
+            dilation=self.dilation,
             groups=groups,
             bias=False)
         self.bn2 = nn.BatchNorm2d(width)
         self.conv3 = nn.Conv2d(
-            width, planes * self.expansion, kernel_size=1, bias=False)
-        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
-        self.relu = nn.ReLU(inplace=True)
-        self.downsample = downsample
-        self.stride = stride
-        self.dilation = dilation
-        self.with_cp = with_cp
-
-    def forward(self, x):
-
-        def _inner_forward(x):
-            identity = x
-
-            out = self.conv1(x)
-            out = self.bn1(out)
-            out = self.relu(out)
-
-            out = self.conv2(out)
-            out = self.bn2(out)
-            out = self.relu(out)
-
-            out = self.conv3(out)
-            out = self.bn3(out)
-
-            if self.downsample is not None:
-                identity = self.downsample(x)
-
-            out += identity
-
-            return out
-
-        if self.with_cp and x.requires_grad:
-            out = cp.checkpoint(_inner_forward, x)
-        else:
-            out = _inner_forward(x)
-
-        out = self.relu(out)
-
-        return out
+            width, self.planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(self.planes * self.expansion)
 
 
 def make_res_layer(block,
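After this hunk the only ResNeXt-specific logic left in the subclass is the grouped-convolution width. The worked example below restates that formula as a hypothetical helper (resnext_width is not part of this codebase); the 32x4d setting used in the last two calls is an assumed configuration, not something shown in this diff.

import math

def resnext_width(planes, groups=1, base_width=4):
    # Width of the grouped 3x3 conv, as computed in Bottleneck.__init__ above.
    if groups == 1:
        return planes
    return math.floor(planes * (base_width / 64)) * groups

print(resnext_width(64))                           # 64   (plain ResNet bottleneck)
print(resnext_width(64, groups=32, base_width=4))  # 128  (32x4d, first stage)
print(resnext_width(512, groups=32, base_width=4)) # 1024 (32x4d, last stage)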
@@ -120,9 +70,9 @@ def make_res_layer(block,
         block(
             inplanes,
             planes,
-            stride,
-            dilation,
-            downsample,
+            stride=stride,
+            dilation=dilation,
+            downsample=downsample,
             groups=groups,
             base_width=base_width,
             style=style,
@@ -133,8 +83,8 @@ def make_res_layer(block,
             block(
                 inplanes,
                 planes,
-                1,
-                dilation,
+                stride=1,
+                dilation=dilation,
                 groups=groups,
                 base_width=base_width,
                 style=style,
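These two hunks switch the block(...) calls to keyword arguments, which matters because the new Bottleneck.__init__ forwards positional arguments blindly through *args while consuming groups and base_width as keyword-only parameters. The sketch below illustrates that forwarding pattern with toy classes (Base and Child are illustrative, not the project's code).

class Base:
    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        self.stride, self.dilation, self.downsample = stride, dilation, downsample

class Child(Base):
    # groups/base_width are consumed here; everything else passes through untouched.
    def __init__(self, *args, groups=1, base_width=4, **kwargs):
        super().__init__(*args, **kwargs)
        self.groups, self.base_width = groups, base_width

# Keyword arguments make it explicit which values reach the base class
# and which are consumed by the subclass.
blk = Child(256, 64, stride=2, dilation=1, downsample=None, groups=32, base_width=4)
print(blk.stride, blk.groups)  # 2 32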
@@ -172,12 +122,8 @@ class ResNeXt(ResNet):
         152: (Bottleneck, (3, 8, 36, 3))
     }
 
-    def __init__(self,
-                 groups=1,
-                 base_width=4,
-                 *args,
-                 **kwargs):
-        super(ResNeXt, self).__init__(*args, **kwargs)
+    def __init__(self, groups=1, base_width=4, **kwargs):
+        super(ResNeXt, self).__init__(**kwargs)
         self.groups = groups
         self.base_width = base_width
...
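With the constructor reduced to keyword forwarding, a ResNeXt backbone is configured entirely through keyword arguments. The usage sketch below is an assumption, not confirmed by this diff: the import path is hypothetical, and it presumes ResNet.__init__ accepts a depth keyword selecting an arch_settings entry (152 is the only key visible above).

# Hypothetical usage sketch; import path and the depth keyword are assumptions.
from resnext import ResNeXt

model = ResNeXt(depth=152, groups=32, base_width=4)
print(model.groups, model.base_width)  # 32 4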