You need to sign in or sign up before continuing.
Unverified Commit 9a481d0b authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Add dilation option to ResNet (#866)

* Add dilation option to ResNet

* Add a size check for replace_stride_with_dilation
parent 3b59c6ee
from collections import OrderedDict
from itertools import product
import torch import torch
from torchvision import models from torchvision import models
import unittest import unittest
...@@ -10,13 +12,34 @@ def get_available_models(): ...@@ -10,13 +12,34 @@ def get_available_models():
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
def _test_model(self, name, input_shape): def _test_model(self, name, input_shape):
# passing num_class equal to a number other than 1000 helps in making the test more enforcing in nature # passing num_class equal to a number other than 1000 helps in making the test
# more enforcing in nature
model = models.__dict__[name](num_classes=50) model = models.__dict__[name](num_classes=50)
model.eval() model.eval()
x = torch.rand(input_shape) x = torch.rand(input_shape)
out = model(x) out = model(x)
self.assertEqual(out.shape[-1], 50) self.assertEqual(out.shape[-1], 50)
def _make_sliced_model(self, model, stop_layer):
layers = OrderedDict()
for name, layer in model.named_children():
layers[name] = layer
if name == stop_layer:
break
new_model = torch.nn.Sequential(layers)
return new_model
def test_resnet_dilation(self):
# TODO improve tests to also check that each layer has the right dimensionality
for i in product([False, True], [False, True], [False, True]):
model = models.__dict__["resnet50"](replace_stride_with_dilation=i)
model = self._make_sliced_model(model, stop_layer="layer4")
model.eval()
x = torch.rand(1, 3, 224, 224)
out = model(x)
f = 2 ** sum(i)
self.assertEqual(out.shape, (1, 2048, 7 * f, 7 * f))
for model_name in get_available_models(): for model_name in get_available_models():
# for-loop bodies don't define scopes, so we have to save the variables # for-loop bodies don't define scopes, so we have to save the variables
......
...@@ -15,10 +15,10 @@ model_urls = { ...@@ -15,10 +15,10 @@ model_urls = {
} }
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding.

    ``padding`` is set equal to ``dilation`` so that, for ``stride == 1``,
    the output keeps the input's spatial size regardless of the dilation
    rate. No bias is used: a norm layer is expected to follow.
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False,
                     dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1): def conv1x1(in_planes, out_planes, stride=1):
...@@ -30,12 +30,14 @@ class BasicBlock(nn.Module): ...@@ -30,12 +30,14 @@ class BasicBlock(nn.Module):
expansion = 1 expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, norm_layer=None): base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__() super(BasicBlock, self).__init__()
if norm_layer is None: if norm_layer is None:
norm_layer = nn.BatchNorm2d norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64: if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64') raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1 # Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride) self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes) self.bn1 = norm_layer(planes)
...@@ -68,7 +70,7 @@ class Bottleneck(nn.Module): ...@@ -68,7 +70,7 @@ class Bottleneck(nn.Module):
expansion = 4 expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, norm_layer=None): base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__() super(Bottleneck, self).__init__()
if norm_layer is None: if norm_layer is None:
norm_layer = nn.BatchNorm2d norm_layer = nn.BatchNorm2d
...@@ -76,7 +78,7 @@ class Bottleneck(nn.Module): ...@@ -76,7 +78,7 @@ class Bottleneck(nn.Module):
# Both self.conv2 and self.downsample layers downsample the input when stride != 1 # Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width) self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width) self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups) self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width) self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion) self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion) self.bn3 = norm_layer(planes * self.expansion)
...@@ -110,12 +112,22 @@ class Bottleneck(nn.Module): ...@@ -110,12 +112,22 @@ class Bottleneck(nn.Module):
class ResNet(nn.Module): class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, norm_layer=None): groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super(ResNet, self).__init__() super(ResNet, self).__init__()
if norm_layer is None: if norm_layer is None:
norm_layer = nn.BatchNorm2d norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64 self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups self.groups = groups
self.base_width = width_per_group self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
...@@ -123,10 +135,13 @@ class ResNet(nn.Module): ...@@ -123,10 +135,13 @@ class ResNet(nn.Module):
self.bn1 = norm_layer(self.inplanes) self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer) self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer) dilate=replace_stride_with_dilation[0])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer) self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes) self.fc = nn.Linear(512 * block.expansion, num_classes)
...@@ -147,10 +162,13 @@ class ResNet(nn.Module): ...@@ -147,10 +162,13 @@ class ResNet(nn.Module):
elif isinstance(m, BasicBlock): elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0) nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None): def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
if norm_layer is None: norm_layer = self._norm_layer
norm_layer = nn.BatchNorm2d
downsample = None downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion: if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential( downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride), conv1x1(self.inplanes, planes * block.expansion, stride),
...@@ -159,11 +177,12 @@ class ResNet(nn.Module): ...@@ -159,11 +177,12 @@ class ResNet(nn.Module):
layers = [] layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups, layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, norm_layer)) self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion self.inplanes = planes * block.expansion
for _ in range(1, blocks): for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups, layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, norm_layer=norm_layer)) base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers) return nn.Sequential(*layers)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment