Unverified commit b872eb8c, authored by Hang Zhang, committed by GitHub

ResNeSt plus (#256)

parent 5a1e3fbc
from .resnet import *
from .resnest import *
from .resnext import *
from .resnet_variants import *
from .wideresnet import *
from .xception import *
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""ResNeSt models"""
import torch
from .resnet import ResNet, Bottleneck
from ..model_store import get_model_file
__all__ = ['resnest50', 'resnest101', 'resnest200', 'resnest269']
_url_format = 'https://hangzh.s3.amazonaws.com/encoding/models/{}-{}.pth'
def resnest50(pretrained=False, root='~/.encoding/models', **kwargs):
model = ResNet(Bottleneck, [3, 4, 6, 3],
radix=2, groups=1, bottleneck_width=64,
deep_stem=True, stem_width=32, avg_down=True,
avd=True, avd_first=False, **kwargs)
if pretrained:
model.load_state_dict(torch.load(
get_model_file('resnest50', root=root)), strict=False)
return model
def resnest101(pretrained=False, root='~/.encoding/models', **kwargs):
model = ResNet(Bottleneck, [3, 4, 23, 3],
radix=2, groups=1, bottleneck_width=64,
deep_stem=True, stem_width=64, avg_down=True,
avd=True, avd_first=False, **kwargs)
if pretrained:
model.load_state_dict(torch.load(
get_model_file('resnest101', root=root)), strict=False)
return model
def resnest200(pretrained=False, root='~/.encoding/models', **kwargs):
model = ResNet(Bottleneck, [3, 24, 36, 3],
radix=2, groups=1, bottleneck_width=64,
deep_stem=True, stem_width=64, avg_down=True,
avd=True, avd_first=False, **kwargs)
if pretrained:
model.load_state_dict(torch.load(
get_model_file('resnest200', root=root)), strict=False)
return model
def resnest269(pretrained=False, root='~/.encoding/models', **kwargs):
model = ResNet(Bottleneck, [3, 30, 48, 8],
radix=2, groups=1, bottleneck_width=64,
deep_stem=True, stem_width=64, avg_down=True,
avd=True, avd_first=False, **kwargs)
if pretrained:
model.load_state_dict(torch.load(
get_model_file('resnest269', root=root)), strict=False)
return model
def resnest50_fast(pretrained=False, root='~/.encoding/models', **kwargs):
model = ResNet(Bottleneck, [3, 4, 6, 3],
radix=2, groups=1, bottleneck_width=64,
deep_stem=True, stem_width=32, avg_down=True,
avd=True, avd_first=True, **kwargs)
if pretrained:
model.load_state_dict(torch.load(
get_model_file('resnest50fast', root=root)), strict=False)
return model
def resnest101_fast(pretrained=False, root='~/.encoding/models', **kwargs):
model = ResNet(Bottleneck, [3, 4, 23, 3],
radix=2, groups=1, bottleneck_width=64,
deep_stem=True, stem_width=64, avg_down=True,
avd=True, avd_first=True, **kwargs)
if pretrained:
model.load_state_dict(torch.load(
get_model_file('resnest101fast', root=root)), strict=False)
return model
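For reference, a minimal usage sketch of the constructors above (editor's example, not part of the commit; the import path assumes the new `encoding.models.backbone` layout):

import torch
from encoding.models.backbone import resnest50  # assumed import path

net = resnest50(pretrained=False)       # radix=2, deep stem, avg-down shortcut
net.eval()
with torch.no_grad():
    out = net(torch.randn(1, 3, 224, 224))
print(out.shape)                        # torch.Size([1, 1000])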
"""Dilated ResNet""" ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""ResNet variants"""
import math
import torch
import torch.utils.model_zoo as model_zoo
import torch.nn as nn

from ...nn import SplAtConv2d, DropBlock2D, GlobalAvgPool2d, RFConv2d
from ..model_store import get_model_file
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'BasicBlock', 'Bottleneck']
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
}
def conv3x3(in_planes, out_planes, stride=1):
"3x3 convolution with padding"
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
class BasicBlock(nn.Module):
"""ResNet BasicBlock
"""
expansion = 1
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, previous_dilation=1,
norm_layer=None):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
padding=dilation, dilation=dilation, bias=False)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
padding=previous_dilation, dilation=previous_dilation, bias=False)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
__all__ = ['ResNet', 'Bottleneck',
'resnet50', 'resnet101', 'resnet152']
class Bottleneck(nn.Module):
    """ResNet Bottleneck
    """
    # pylint: disable=unused-argument
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 radix=1, cardinality=1, bottleneck_width=64,
                 avd=False, avd_first=False, dilation=1, is_first=False,
                 rectified_conv=False, rectify_avg=False,
                 norm_layer=None, dropblock_prob=0.0, last_gamma=False):
        super(Bottleneck, self).__init__()
        group_width = int(planes * (bottleneck_width / 64.)) * cardinality
        self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
        self.bn1 = norm_layer(group_width)
        self.dropblock_prob = dropblock_prob
        self.radix = radix
        self.avd = avd and (stride > 1 or is_first)
        self.avd_first = avd_first

        if self.avd:
            self.avd_layer = nn.AvgPool2d(3, stride, padding=1)
            stride = 1

        if dropblock_prob > 0.0:
            self.dropblock1 = DropBlock2D(dropblock_prob, 3)
            if radix == 1:
                self.dropblock2 = DropBlock2D(dropblock_prob, 3)
            self.dropblock3 = DropBlock2D(dropblock_prob, 3)

        if radix > 1:
            self.conv2 = SplAtConv2d(
                group_width, group_width, kernel_size=3,
                stride=stride, padding=dilation,
                dilation=dilation, groups=cardinality, bias=False,
                radix=radix, rectify=rectified_conv,
                rectify_avg=rectify_avg,
                norm_layer=norm_layer,
                dropblock_prob=dropblock_prob)
        elif rectified_conv:
            self.conv2 = RFConv2d(
                group_width, group_width, kernel_size=3, stride=stride,
                padding=dilation, dilation=dilation,
                groups=cardinality, bias=False,
                average_mode=rectify_avg)
            self.bn2 = norm_layer(group_width)
        else:
            self.conv2 = nn.Conv2d(
                group_width, group_width, kernel_size=3, stride=stride,
                padding=dilation, dilation=dilation,
                groups=cardinality, bias=False)
            self.bn2 = norm_layer(group_width)

        self.conv3 = nn.Conv2d(
            group_width, planes * 4, kernel_size=1, bias=False)
        self.bn3 = norm_layer(planes * 4)

        if last_gamma:
            from torch.nn.init import zeros_
            zeros_(self.bn3.weight)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride

    def _sum_each(self, x, y):
        assert len(x) == len(y)
        z = []
        for i in range(len(x)):
            z.append(x[i] + y[i])
        return z

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        if self.dropblock_prob > 0.0:
            out = self.dropblock1(out)
        out = self.relu(out)

        if self.avd and self.avd_first:
            out = self.avd_layer(out)

        out = self.conv2(out)
        if self.radix == 1:
            out = self.bn2(out)
            if self.dropblock_prob > 0.0:
                out = self.dropblock2(out)
            out = self.relu(out)

        if self.avd and not self.avd_first:
            out = self.avd_layer(out)

        out = self.conv3(out)
        out = self.bn3(out)
        if self.dropblock_prob > 0.0:
            out = self.dropblock3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

...@@ -109,9 +115,8 @@ class Bottleneck(nn.Module):
        return out
class ResNet(nn.Module):
    """ResNet Variants

    Parameters
    ----------
...@@ -135,44 +140,73 @@ class ResNet(nn.Module):
    - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
    """
    # pylint: disable=unused-variable
    def __init__(self, block, layers, radix=1, groups=1, bottleneck_width=64,
                 num_classes=1000, dilated=False, dilation=1,
                 deep_stem=False, stem_width=64, avg_down=False,
                 rectified_conv=False, rectify_avg=False,
                 avd=False, avd_first=False,
                 final_drop=0.0, dropblock_prob=0,
                 last_gamma=False, norm_layer=nn.BatchNorm2d):
        self.cardinality = groups
        self.bottleneck_width = bottleneck_width
        # ResNet-D params
        self.inplanes = stem_width*2 if deep_stem else 64
        self.avg_down = avg_down
        self.last_gamma = last_gamma
        # ResNeSt params
        self.radix = radix
        self.avd = avd
        self.avd_first = avd_first

        super(ResNet, self).__init__()
        self.rectified_conv = rectified_conv
        self.rectify_avg = rectify_avg
        if rectified_conv:
            conv_layer = RFConv2d
        else:
            conv_layer = nn.Conv2d
        conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {}
        if deep_stem:
            self.conv1 = nn.Sequential(
                conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False, **conv_kwargs),
                norm_layer(stem_width),
                nn.ReLU(inplace=True),
                conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
                norm_layer(stem_width),
                nn.ReLU(inplace=True),
                conv_layer(stem_width, stem_width*2, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
            )
        else:
            self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3,
                                    bias=False, **conv_kwargs)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
        if dilated or dilation == 4:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
                                           dilation=2, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                           dilation=4, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
        elif dilation == 2:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                           dilation=1, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                           dilation=2, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
        else:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                           norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                           norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
        self.avgpool = GlobalAvgPool2d()
        self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
...@@ -183,37 +217,58 @@ class ResNet(nn.Module):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None,
                    dropblock_prob=0.0, is_first=True):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            down_layers = []
            if self.avg_down:
                if dilation == 1:
                    down_layers.append(nn.AvgPool2d(kernel_size=stride, stride=stride,
                                                    ceil_mode=True, count_include_pad=False))
                else:
                    down_layers.append(nn.AvgPool2d(kernel_size=1, stride=1,
                                                    ceil_mode=True, count_include_pad=False))
                down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
                                             kernel_size=1, stride=1, bias=False))
            else:
                down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
                                             kernel_size=1, stride=stride, bias=False))
            down_layers.append(norm_layer(planes * block.expansion))
            downsample = nn.Sequential(*down_layers)

        layers = []
        if dilation == 1 or dilation == 2:
            layers.append(block(self.inplanes, planes, stride, downsample=downsample,
                                radix=self.radix, cardinality=self.cardinality,
                                bottleneck_width=self.bottleneck_width,
                                avd=self.avd, avd_first=self.avd_first,
                                dilation=1, is_first=is_first, rectified_conv=self.rectified_conv,
                                rectify_avg=self.rectify_avg,
                                norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                                last_gamma=self.last_gamma))
        elif dilation == 4:
            layers.append(block(self.inplanes, planes, stride, downsample=downsample,
                                radix=self.radix, cardinality=self.cardinality,
                                bottleneck_width=self.bottleneck_width,
                                avd=self.avd, avd_first=self.avd_first,
                                dilation=2, is_first=is_first, rectified_conv=self.rectified_conv,
                                rectify_avg=self.rectify_avg,
                                norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                                last_gamma=self.last_gamma))
        else:
            raise RuntimeError("=> unknown dilation size: {}".format(dilation))

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes,
                                radix=self.radix, cardinality=self.cardinality,
                                bottleneck_width=self.bottleneck_width,
                                avd=self.avd, avd_first=self.avd_first,
                                dilation=dilation, rectified_conv=self.rectified_conv,
                                rectify_avg=self.rectify_avg,
                                norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                                last_gamma=self.last_gamma))

        return nn.Sequential(*layers)

...@@ -229,36 +284,14 @@ class ResNet(nn.Module):
        x = self.layer4(x)

        x = self.avgpool(x)
        #x = x.view(x.size(0), -1)
        x = torch.flatten(x, 1)
        if self.drop:
            x = self.drop(x)
        x = self.fc(x)

        return x
def resnet18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
return model
def resnet34(pretrained=False, **kwargs):
"""Constructs a ResNet-34 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
return model
def resnet50(pretrained=False, root='~/.encoding/models', **kwargs):
    """Constructs a ResNet-50 model.
......
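A quick way to see the effect of `dilated=True` (editor's sketch, assuming the `ResNet` and `Bottleneck` classes above are in scope): layer3 and layer4 keep stride 1 with dilations 2 and 4, so a 224x224 input yields stride-8 (28x28) conv5 features instead of stride-32.

import torch

net = ResNet(Bottleneck, [3, 4, 6, 3], dilated=True)   # plain ResNet-50 blocks, radix=1
x = torch.randn(1, 3, 224, 224)
x = net.maxpool(net.relu(net.bn1(net.conv1(x))))
for layer in (net.layer1, net.layer2, net.layer3, net.layer4):
    x = layer(x)
print(x.shape)   # torch.Size([1, 2048, 28, 28]) -> stride 8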
"""ResNet variants"""
import torch
from .resnet import ResNet, Bottleneck
from ..model_store import get_model_file
__all__ = ['resnet50s', 'resnet101s', 'resnet152s',
'resnet50d']
# pspnet version of ResNet
def resnet50s(pretrained=False, root='~/.encoding/models', **kwargs):
"""Constructs a ResNetS-50 model as in PSPNet.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
kwargs['deep_stem'] = True
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(torch.load(
get_model_file('resnet50', root=root)), strict=False)
return model
def resnet101s(pretrained=False, root='~/.encoding/models', **kwargs):
"""Constructs a ResNetS-101 model as in PSPNet.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
kwargs['deep_stem'] = True
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained:
model.load_state_dict(torch.load(
get_model_file('resnet101', root=root)), strict=False)
return model
def resnet152s(pretrained=False, root='~/.encoding/models', **kwargs):
"""Constructs a ResNetS-152 model as in PSPNet.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
kwargs['deep_stem'] = True
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained:
model.load_state_dict(torch.load(
get_model_file('resnet152', root=root)), strict=False)
return model
# ResNet-D
def resnet50d(pretrained=False, root='~/.encoding/models', **kwargs):
model = ResNet(Bottleneck, [3, 4, 6, 3],
deep_stem=True, stem_width=32,
avg_down=True, **kwargs)
if pretrained:
model.load_state_dict(torch.load(
get_model_file('resnet50d', root=root)), strict=False)
return model
"""ResNeXt models"""
from .resnet import ResNet, Bottleneck
from ..model_store import get_model_file
__all__ = ['resnext50_32x4d', 'resnext101_32x8d']
def resnext50_32x4d(pretrained=False, root='~/.encoding/models', **kwargs):
r"""ResNeXt-50 32x4d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['bottleneck_width'] = 4
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(torch.load(
get_model_file('resnext50_32x4d', root=root)), strict=False)
return model
def resnext101_32x8d(pretrained=False, root='~/.encoding/models', **kwargs):
r"""ResNeXt-101 32x8d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['bottleneck_width'] = 8
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained:
model.load_state_dict(torch.load(
get_model_file('resnext101_32x8d', root=root)), strict=False)
return model
import sys
from collections import OrderedDict
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...nn import SyncBatchNorm, GlobalAvgPool2d
from ..model_store import get_model_file
__all__ = ['WideResNet', 'wideresnet38', 'wideresnet50']
ABN = partial(SyncBatchNorm, activation='leaky_relu', slope=0.01, sync=True, inplace=True)
class BasicBlock(nn.Module):
"""WideResNet BasicBlock
"""
def __init__(self, inplanes, planes, stride=1, dilation=1, expansion=1, downsample=None,
previous_dilation=1, dropout=0.0, **kwargs):
super(BasicBlock, self).__init__()
self.bn1 = ABN(inplanes)
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
padding=dilation, dilation=dilation, bias=False)
self.bn2 = ABN(planes)
self.conv2 = nn.Conv2d(planes, planes * expansion, kernel_size=3,
stride=1, padding=previous_dilation, dilation=previous_dilation,
bias=False)
self.downsample = downsample
self.drop = None
if dropout > 0.0:
self.drop = nn.Dropout(dropout)
def forward(self, x):
if self.downsample:
bn1 = self.bn1(x)
residual = self.downsample(bn1)
else:
residual = x.clone()
bn1 = self.bn1(x)
out = self.conv1(bn1)
out = self.bn2(out)
if self.drop:
out = self.drop(out)
out = self.conv2(out)
out = out + residual
return out
class Bottleneck(nn.Module):
"""WideResNet BottleneckV1b
"""
# pylint: disable=unused-argument
def __init__(self, inplanes, planes, stride=1, dilation=1, expansion=4, dropout=0.0,
downsample=None, previous_dilation=1, **kwargs):
super(Bottleneck, self).__init__()
self.bn1 = ABN(inplanes)
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn2 = ABN(planes)
self.conv2 = nn.Conv2d(planes, planes*expansion//2, kernel_size=3, stride=stride,
padding=dilation, dilation=dilation, bias=False)
self.bn3 = ABN(planes*expansion//2)
self.conv3 = nn.Conv2d(planes*expansion//2, planes*expansion, kernel_size=1,
bias=False)
self.downsample = downsample
self.drop = None
if dropout > 0.0:
self.drop = nn.Dropout(dropout)
def forward(self, x):
if self.downsample:
bn1 = self.bn1(x)
residual = self.downsample(bn1)
else:
residual = x.clone()
bn1 = self.bn1(x)
out = self.conv1(bn1)
out = self.bn2(out)
out = self.conv2(out)
out = self.bn3(out)
if self.drop:
out = self.drop(out)
out = self.conv3(out)
out = out + residual
return out
class WideResNet(nn.Module):
""" Pre-trained WideResNet Model
featuremaps at conv5.
Parameters
----------
layers : list of int
Numbers of layers in each block
classes : int, default 1000
Number of classification classes.
dilated : bool, default False
Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
typically used in Semantic Segmentation.
final_drop : float, default 0.0
Dropout ratio before the final classification layer.
Reference:
- Zifeng Wu, et al. "Wider or Deeper: Revisiting the ResNet Model for Visual Recognition"
- Samuel Rota Bulò, et al.
"In-Place Activated BatchNorm for Memory-Optimized Training of DNNs"
"""
# pylint: disable=unused-variable
def __init__(self, layers, classes=1000, dilated=False, **kwargs):
self.inplanes = 64
super(WideResNet, self).__init__()
self.mod1 = nn.Conv2d(3, 64, kernel_size=3, stride=1,
padding=1, bias=False)
self.pool2 = nn.MaxPool2d(3, stride=2, padding=1)
self.mod2 = self._make_layer(2, BasicBlock, 128, layers[0])
self.pool3 = nn.MaxPool2d(3, stride=2, padding=1)
self.mod3 = self._make_layer(3, BasicBlock, 256, layers[1], stride=1)
self.mod4 = self._make_layer(4, BasicBlock, 512, layers[2], stride=2)
if dilated:
self.mod5 = self._make_layer(5, BasicBlock, 512, layers[3], stride=1, dilation=2,
expansion=2)
self.mod6 = self._make_layer(6, Bottleneck, 512, layers[4], stride=1, dilation=4,
expansion=4, dropout=0.3)
self.mod7 = self._make_layer(7, Bottleneck, 1024, layers[5], stride=1, dilation=4,
expansion=4, dropout=0.5)
else:
self.mod5 = self._make_layer(5, BasicBlock, 512, layers[3], stride=2, expansion=2)
self.mod6 = self._make_layer(6, Bottleneck, 512, layers[4], stride=2,
expansion=4, dropout=0.3)
self.mod7 = self._make_layer(7, Bottleneck, 1024, layers[5], stride=1, expansion=4,
dropout=0.5)
self.bn_out = ABN(4096)
self.avgpool = GlobalAvgPool2d()
self.fc = nn.Linear(4096, classes)
def _make_layer(self, stage_index, block, planes, blocks, stride=1, dilation=1, expansion=1,
dropout=0.0):
downsample = None
if stride != 1 or self.inplanes != planes * expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * expansion,
kernel_size=1, stride=stride, bias=False),
)
layers = []
if dilation in (1, 2):
layers.append(block(self.inplanes, planes, stride, dilation=1, expansion=expansion,
dropout=dropout, downsample=downsample, previous_dilation=dilation))
elif dilation == 4 and stage_index < 7:
layers.append(block(self.inplanes, planes, stride, dilation=2, expansion=expansion,
dropout=dropout, downsample=downsample, previous_dilation=dilation))
else:
assert stage_index == 7
layers.append(block(self.inplanes, planes, stride, dilation=dilation, expansion=expansion,
dropout=dropout, downsample=downsample, previous_dilation=dilation))
self.inplanes = planes * expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, dilation=dilation, expansion=expansion,
dropout=dropout, previous_dilation=dilation))
return nn.Sequential(*layers)
def forward(self, x):
x = self.mod1(x)
x = self.pool2(x)
x = self.mod2(x)
x = self.pool3(x)
x = self.mod3(x)
x = self.mod4(x)
x = self.mod5(x)
x = self.mod6(x)
x = self.mod7(x)
x = self.bn_out(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
def wideresnet38(pretrained=False, root='~/.encoding/models', **kwargs):
"""Constructs a WideResNet-38 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = WideResNet([3, 3, 6, 3, 1, 1], **kwargs)
if pretrained:
model.load_state_dict(torch.load(
get_model_file('wideresnet38', root=root)), strict=False)
return model
def wideresnet50(pretrained=False, root='~/.encoding/models', **kwargs):
"""Constructs a WideResNet-50 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = WideResNet([3, 3, 6, 6, 3, 1], **kwargs)
if pretrained:
model.load_state_dict(torch.load(
get_model_file('wideresnet50', root=root)), strict=False)
return model
# code adapted from https://github.com/jfzhang95/pytorch-deeplab-xception/
import math
import torch
from collections import OrderedDict
import torch.nn as nn
import torch.nn.functional as F
from ...nn import SyncBatchNorm, GlobalAvgPool2d
from ..model_store import get_model_file
__all__ = ['Xception65', 'Xception71', 'xception65']
def fixed_padding(inputs, kernel_size, dilation):
kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
pad_total = kernel_size_effective - 1
pad_beg = pad_total // 2
pad_end = pad_total - pad_beg
padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
return padded_inputs
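# Editor's note (illustration, not part of the commit): with kernel_size=3 and
# dilation=2 the effective kernel is 3 + (3 - 1) * (2 - 1) = 5, so fixed_padding
# adds 2 pixels on every side and a following stride-1, padding-0 depthwise conv
# preserves the spatial size:
#   x = torch.randn(1, 8, 32, 32)
#   fixed_padding(x, kernel_size=3, dilation=2).shape                  # (1, 8, 36, 36)
#   nn.Conv2d(8, 8, 3, dilation=2, groups=8, bias=False)(
#       fixed_padding(x, 3, 2)).shape                                  # (1, 8, 32, 32)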
class SeparableConv2d(nn.Module):
def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, norm_layer=None):
super(SeparableConv2d, self).__init__()
self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation,
groups=inplanes, bias=bias)
self.bn = norm_layer(inplanes)
self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
def forward(self, x):
x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0])
x = self.conv1(x)
x = self.bn(x)
x = self.pointwise(x)
return x
class Block(nn.Module):
def __init__(self, inplanes, planes, reps, stride=1, dilation=1, norm_layer=None,
start_with_relu=True, grow_first=True, is_last=False):
super(Block, self).__init__()
if planes != inplanes or stride != 1:
self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False)
self.skipbn = norm_layer(planes)
else:
self.skip = None
self.relu = nn.ReLU(inplace=True)
rep = []
filters = inplanes
if grow_first:
if start_with_relu:
rep.append(self.relu)
rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, norm_layer=norm_layer))
rep.append(norm_layer(planes))
filters = planes
for i in range(reps - 1):
if grow_first or start_with_relu:
rep.append(self.relu)
rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, norm_layer=norm_layer))
rep.append(norm_layer(filters))
if not grow_first:
rep.append(self.relu)
rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, norm_layer=norm_layer))
rep.append(norm_layer(planes))
if stride != 1:
rep.append(self.relu)
rep.append(SeparableConv2d(planes, planes, 3, 2, norm_layer=norm_layer))
rep.append(norm_layer(planes))
elif is_last:
rep.append(self.relu)
rep.append(SeparableConv2d(planes, planes, 3, 1, dilation, norm_layer=norm_layer))
rep.append(norm_layer(planes))
#if not start_with_relu:
# rep = rep[1:]
self.rep = nn.Sequential(*rep)
def forward(self, inp):
x = self.rep(inp)
if self.skip is not None:
skip = self.skip(inp)
skip = self.skipbn(skip)
else:
skip = inp
x = x + skip
return x
class Xception65(nn.Module):
"""Modified Aligned Xception
"""
def __init__(self, output_stride=32, norm_layer=nn.BatchNorm2d):
super(Xception65, self).__init__()
if output_stride == 32:
entry_block3_stride = 2
middle_block_dilation = 1
exit_block20_stride = 2
exit_block_dilations = (1, 1)
elif output_stride == 16:
entry_block3_stride = 2
middle_block_dilation = 1
exit_block20_stride = 1
exit_block_dilations = (1, 2)
elif output_stride == 8:
entry_block3_stride = 1
middle_block_dilation = 2
exit_block20_stride = 1
exit_block_dilations = (2, 4)
else:
raise NotImplementedError
# Entry flow
self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False)
self.bn1 = norm_layer(32)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
self.bn2 = norm_layer(64)
self.block1 = Block(64, 128, reps=2, stride=2, norm_layer=norm_layer, start_with_relu=False)
self.block2 = Block(128, 256, reps=2, stride=2, norm_layer=norm_layer, start_with_relu=False,
grow_first=True)
#print('self.block2', self.block2)
self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, norm_layer=norm_layer,
start_with_relu=True, grow_first=True, is_last=True)
# Middle flow
midflowblocks = []
for i in range(4, 20):
midflowblocks.append(('block%d'%i, Block(728, 728, reps=3, stride=1,
dilation=middle_block_dilation,
norm_layer=norm_layer, start_with_relu=True,
grow_first=True)))
self.midflow = nn.Sequential(OrderedDict(midflowblocks))
# Exit flow
self.block20 = Block(728, 1024, reps=2, stride=exit_block20_stride, dilation=exit_block_dilations[0],
norm_layer=norm_layer, start_with_relu=True, grow_first=False, is_last=True)
self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], norm_layer=norm_layer)
self.bn3 = norm_layer(1536)
self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], norm_layer=norm_layer)
self.bn4 = norm_layer(1536)
self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], norm_layer=norm_layer)
self.bn5 = norm_layer(2048)
self.avgpool = GlobalAvgPool2d()
self.fc = nn.Linear(2048, 1000)
# Init weights
self._init_weight()
def forward(self, x):
# Entry flow
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.block1(x)
# add relu here
x = self.relu(x)
#c1 = x
x = self.block2(x)
#c2 = x
x = self.block3(x)
# Middle flow
x = self.midflow(x)
#c3 = x
# Exit flow
x = self.block20(x)
x = self.relu(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.relu(x)
x = self.conv4(x)
x = self.bn4(x)
x = self.relu(x)
x = self.conv5(x)
x = self.bn5(x)
x = self.relu(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
def _init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, SyncBatchNorm):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
class Xception71(nn.Module):
"""Modified Aligned Xception
"""
def __init__(self, output_stride=32, norm_layer=nn.BatchNorm2d):
super(Xception71, self).__init__()
if output_stride == 32:
entry_block3_stride = 2
middle_block_dilation = 1
exit_block20_stride = 2
exit_block_dilations = (1, 1)
elif output_stride == 16:
entry_block3_stride = 2
middle_block_dilation = 1
exit_block20_stride = 1
exit_block_dilations = (1, 2)
elif output_stride == 8:
entry_block3_stride = 1
middle_block_dilation = 2
exit_block20_stride = 1
exit_block_dilations = (2, 4)
else:
raise NotImplementedError
# Entry flow
self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False)
self.bn1 = norm_layer(32)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
self.bn2 = norm_layer(64)
self.block1 = Block(64, 128, reps=2, stride=2, norm_layer=norm_layer, start_with_relu=False)
block2 = []
block2.append(Block(128, 256, reps=2, stride=1, norm_layer=norm_layer, start_with_relu=False,
grow_first=True))
block2.append(Block(256, 256, reps=2, stride=2, norm_layer=norm_layer, start_with_relu=False,
grow_first=True))
block2.append(Block(256, 728, reps=2, stride=1, norm_layer=norm_layer, start_with_relu=False,
grow_first=True))
self.block2 = nn.Sequential(*block2)
self.block3 = Block(728, 728, reps=2, stride=entry_block3_stride, norm_layer=norm_layer,
start_with_relu=True, grow_first=True, is_last=True)
# Middle flow
midflowblocks = []
for i in range(4, 20):
midflowblocks.append(('block%d'%i, Block(728, 728, reps=3, stride=1,
dilation=middle_block_dilation,
norm_layer=norm_layer, start_with_relu=True,
grow_first=True)))
self.midflow = nn.Sequential(OrderedDict(midflowblocks))
# Exit flow
self.block20 = Block(728, 1024, reps=2, stride=exit_block20_stride, dilation=exit_block_dilations[0],
norm_layer=norm_layer, start_with_relu=True, grow_first=False, is_last=True)
self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], norm_layer=norm_layer)
self.bn3 = norm_layer(1536)
self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], norm_layer=norm_layer)
self.bn4 = norm_layer(1536)
self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], norm_layer=norm_layer)
self.bn5 = norm_layer(2048)
self.avgpool = GlobalAvgPool2d()
self.fc = nn.Linear(2048, 1000)
# Init weights
self._init_weight()
def forward(self, x):
# Entry flow
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.block1(x)
# add relu here
x = self.relu(x)
low_level_feat = x
x = self.block2(x)
x = self.block3(x)
# Middle flow
x = self.midflow(x)
# Exit flow
x = self.block20(x)
x = self.relu(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.relu(x)
x = self.conv4(x)
x = self.bn4(x)
x = self.relu(x)
x = self.conv5(x)
x = self.bn5(x)
x = self.relu(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x#, low_level_feat
def _init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, SyncBatchNorm):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def xception65(pretrained=False, root='~/.encoding/models', **kwargs):
    """Constructs an Xception-65 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        root (str): Location of the stored pretrained model parameters
    """
    model = Xception65(**kwargs)
    if pretrained:
        model.load_state_dict(torch.load(get_model_file('xception65', root=root)))
    return model
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import torch
import torch.nn as nn
from torch.autograd import Variable
from ..nn import View
from .model_store import get_model_file  # needed by cifar_resnet20 below
__all__ = ['cifar_resnet20']
def conv3x3(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
class Basicblock(nn.Module):
""" Pre-activation residual block
Identity Mapping in Deep Residual Networks
ref https://arxiv.org/abs/1603.05027
"""
expansion = 1
def __init__(self, inplanes, planes, stride=1, norm_layer=nn.BatchNorm2d):
super(Basicblock, self).__init__()
if inplanes != planes or stride !=1 :
self.downsample = True
self.residual_layer = nn.Conv2d(inplanes, planes,
kernel_size=1, stride=stride)
else:
self.downsample = False
conv_block=[]
conv_block+=[norm_layer(inplanes),
nn.ReLU(inplace=True),
conv3x3(inplanes, planes,stride=stride),
norm_layer(planes),
nn.ReLU(inplace=True),
conv3x3(planes, planes)]
self.conv_block = nn.Sequential(*conv_block)
def forward(self, input):
if self.downsample:
residual = self.residual_layer(input)
else:
residual = input
return residual + self.conv_block(input)
class Bottleneck(nn.Module):
""" Pre-activation residual block
Identity Mapping in Deep Residual Networks
ref https://arxiv.org/abs/1603.05027
"""
expansion = 4
def __init__(self, inplanes, planes, stride=1, norm_layer=nn.BatchNorm2d):
super(Bottleneck, self).__init__()
if inplanes != planes*self.expansion or stride !=1 :
self.downsample = True
self.residual_layer = nn.Conv2d(inplanes,
planes * self.expansion, kernel_size=1, stride=stride)
else:
self.downsample = False
conv_block = []
conv_block += [norm_layer(inplanes),
nn.ReLU(inplace=True),
nn.Conv2d(inplanes, planes, kernel_size=1,
stride=1, bias=False)]
conv_block += [norm_layer(planes),
nn.ReLU(inplace=True),
nn.Conv2d(planes, planes, kernel_size=3,
stride=stride, padding=1, bias=False)]
conv_block += [norm_layer(planes),
nn.ReLU(inplace=True),
nn.Conv2d(planes, planes * self.expansion,
kernel_size=1, stride=1, bias=False)]
self.conv_block = nn.Sequential(*conv_block)
def forward(self, x):
if self.downsample:
residual = self.residual_layer(x)
else:
residual = x
return residual + self.conv_block(x)
class CIFAR_ResNet(nn.Module):
def __init__(self, block=Basicblock, num_blocks=[2,2,2], width_factor = 1,
num_classes=10, norm_layer=torch.nn.BatchNorm2d):
super(CIFAR_ResNet, self).__init__()
self.expansion = block.expansion
self.inplanes = int(width_factor * 16)
strides = [1, 2, 2]
model = []
# Conv_1
model += [nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1),
norm_layer(self.inplanes),
nn.ReLU(inplace=True)]
# Residual units
model += [self._residual_unit(block, self.inplanes, num_blocks[0],
strides[0], norm_layer=norm_layer)]
for i in range(2):
model += [self._residual_unit(
block, int(2*self.inplanes/self.expansion),
num_blocks[i+1], strides[i+1], norm_layer=norm_layer)]
# Last conv layer
model += [norm_layer(self.inplanes),
nn.ReLU(inplace=True),
nn.AvgPool2d(8),
View(-1, self.inplanes),
nn.Linear(self.inplanes, num_classes)]
self.model = nn.Sequential(*model)
def _residual_unit(self, block, planes, n_blocks, stride, norm_layer):
strides = [stride] + [1]*(n_blocks-1)
layers = []
for i in range(n_blocks):
layers += [block(self.inplanes, planes, strides[i], norm_layer=norm_layer)]
self.inplanes = self.expansion*planes
return nn.Sequential(*layers)
def forward(self, input):
return self.model(input)
def cifar_resnet20(pretrained=False, root='~/.encoding/models', **kwargs):
"""Constructs a CIFAR ResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = CIFAR_ResNet(Bottleneck, [3, 3, 3], **kwargs)
if pretrained:
model.load_state_dict(torch.load(
get_model_file('cifar_resnet20', root=root)), strict=False)
return model
...@@ -12,7 +12,7 @@ import torch
import torch.nn as nn
from ..nn import Encoding, View, Normalize
from .backbone import resnet
__all__ = ['DeepTen', 'get_deepten', 'get_deepten_resnet50_minc']
......
...@@ -3,16 +3,32 @@ from __future__ import print_function
__all__ = ['get_model_file', 'purge']
import os
import zipfile
import portalocker

from ..utils import download, check_sha1

_model_sha1 = {name: checksum for checksum, name in [
    # resnet
    ('25c4b50959ef024fcc050213a06b614899f94b3d', 'resnet50'),
    ('2a57e44de9c853fa015b172309a1ee7e2d0e4e2a', 'resnet101'),
    ('0d43d698c66aceaa2bc0309f55efdd7ff4b143af', 'resnet152'),
    # rectified
    ('9b5dc32b3b36ca1a6b41ecd4906830fc84dae8ed', 'resnet101_rt'),
    # resnest
    ('fb9de5b360976e3e8bd3679d3e93c5409a5eff3c', 'resnest50'),
    ('966fb78c22323b0c68097c5c1242bd16d3e07fd5', 'resnest101'),
    ('d7fd712f5a1fcee5b3ce176026fbb6d0d278454a', 'resnest200'),
    ('b743074c6fc40f88d7f53e8affb350de38f4f49d', 'resnest269'),
    # resnet other variants
    ('a75c83cfc89a56a4e8ba71b14f1ec67e923787b3', 'resnet50s'),
    ('03a0f310d6447880f1b22a83bd7d1aa7fc702c6e', 'resnet101s'),
    ('36670e8bc2428ecd5b7db1578538e2dd23872813', 'resnet152s'),
    # other segmentation backbones
    ('da4785cfc837bf00ef95b52fb218feefe703011f', 'wideresnet38'),
    ('b41562160173ee2e979b795c551d3c7143b1e5b5', 'wideresnet50'),
    # deepten paper
    ('1225f149519c7a0113c43a056153c1bb15468ac0', 'deepten_resnet50_minc'),
    # segmentation models
    ('662e979de25a389f11c65e9f1df7e06c2c356381', 'fcn_resnet50_ade'),
    ('eeed8e582f0fdccdba8579e7490570adc6d85c7c', 'fcn_resnet50_pcontext'),
    ('54f70c772505064e30efd1ddd3a14e1759faa363', 'psp_resnet50_ade'),
...@@ -22,6 +38,9 @@ _model_sha1 = {name: checksum for checksum, name in [
    ('9f27ea13d514d7010e59988341bcbd4140fcc33d', 'encnet_resnet101_pcontext'),
    ('07ac287cd77e53ea583f37454e17d30ce1509a4a', 'encnet_resnet50_ade'),
    ('3f54fa3b67bac7619cd9b3673f5c8227cf8f4718', 'encnet_resnet101_ade'),
    # resnest segmentation models
    ('2225f09d0f40b9a168d9091652194bc35ec2a5a9', 'deeplab_resnest50_ade'),
    ('06ca799c8cc148fe0fafb5b6d052052935aa3cc8', 'deeplab_resnest101_ade'),
]}

encoding_repo_url = 'https://hangzh.s3.amazonaws.com/'
...@@ -50,10 +69,24 @@ def get_model_file(name, root=os.path.join('~', '.encoding', 'models')):
    file_path
        Path to the requested pretrained model file.
    """
    if name not in _model_sha1:
        from torchvision.models.resnet import model_urls
        if name not in model_urls:
            raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
        root = os.path.expanduser(root)
        return download(model_urls[name],
                        path=root,
                        overwrite=True)
    file_name = '{name}-{short_hash}'.format(name=name, short_hash=short_hash(name))
    root = os.path.expanduser(root)
    if not os.path.exists(root):
        os.makedirs(root)

    file_path = os.path.join(root, file_name+'.pth')
    sha1_hash = _model_sha1[name]

    lockfile = os.path.join(root, file_name + '.lock')
    with portalocker.Lock(lockfile, timeout=300):
        if os.path.exists(file_path):
            if check_sha1(file_path, sha1_hash):
                return file_path
...@@ -63,9 +96,6 @@ def get_model_file(name, root=os.path.join('~', '.encoding', 'models')):
            else:
                print('Model file {} is not found. Downloading.'.format(file_path))

        zip_file_path = os.path.join(root, file_name+'.zip')
        repo_url = os.environ.get('ENCODING_REPO', encoding_repo_url)
        if repo_url[-1] != '/':
......
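As an editor's illustration of how the backbone constructors consume `get_model_file` (assuming the package is importable as `encoding` and the module path below; not part of the commit):

import torch
from encoding.models.model_store import get_model_file  # assumed module path

path = get_model_file('resnest50', root='~/.encoding/models')  # downloads and sha1-checks
state_dict = torch.load(path)
# the constructors above load with strict=False so classifier/aux heads may be skipped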
# pylint: disable=wildcard-import, unused-wildcard-import
from .backbone import *
from .sseg import *
from .deepten import *

__all__ = ['get_model']

def get_model(name, **kwargs):
    """Returns a pre-defined model by name
...@@ -28,13 +24,29 @@ def get_model(name, **kwargs):
        The model.
    """
    models = {
        # resnet
        'resnet50': resnet50,
        'resnet101': resnet101,
        'resnet152': resnet152,
        # resnest
        'resnest50': resnest50,
        'resnest101': resnest101,
        'resnest200': resnest200,
        'resnest269': resnest269,
        # resnet other variants
        'resnet50s': resnet50s,
        'resnet101s': resnet101s,
        'resnet152s': resnet152s,
        'resnet50d': resnet50d,
        'resnext50_32x4d': resnext50_32x4d,
        'resnext101_32x8d': resnext101_32x8d,
        # other segmentation backbones
        'xception65': xception65,
        'wideresnet38': wideresnet38,
        'wideresnet50': wideresnet50,
        # deepten paper
        'deepten_resnet50_minc': get_deepten_resnet50_minc,
        # segmentation models
        'fcn_resnet50_pcontext': get_fcn_resnet50_pcontext,
        'encnet_resnet50_pcontext': get_encnet_resnet50_pcontext,
        'encnet_resnet101_pcontext': get_encnet_resnet101_pcontext,
...@@ -42,6 +54,8 @@ def get_model(name, **kwargs):
        'encnet_resnet101_ade': get_encnet_resnet101_ade,
        'fcn_resnet50_ade': get_fcn_resnet50_ade,
        'psp_resnet50_ade': get_psp_resnet50_ade,
        'deeplab_resnest50_ade': get_deeplab_resnest50_ade,
        'deeplab_resnest101_ade': get_deeplab_resnest101_ade,
    }
    name = name.lower()
    if name not in models:
......
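With the table above in place, models are typically obtained through the single entry point (editor's sketch; the names mirror the keys registered above):

import encoding

cls_net = encoding.models.get_model('resnest101', pretrained=False)
seg_net = encoding.models.get_model('deeplab_resnest50_ade', pretrained=False)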
from .base import *
from .fcn import *
from .psp import *
from .fcfpn import *
from .atten import *
from .encnet import *
from .deeplab import *
from .upernet import *
def get_segmentation_model(name, **kwargs):
models = {
'fcn': get_fcn,
'psp': get_psp,
'fcfpn': get_fcfpn,
'atten': get_atten,
'encnet': get_encnet,
'upernet': get_upernet,
'deeplab': get_deeplab,
}
return models[name.lower()](**kwargs)
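For example, the new `deeplab` entry can be paired with a ResNeSt backbone through this dispatcher (editor's sketch; the dataset and backbone names are examples, not an exhaustive list):

model = get_segmentation_model('deeplab', dataset='ade20k', backbone='resnest50',
                               aux=True, se_loss=False)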
###########################################################################
# Created by: Hang Zhang
# Email: zhang.hang@rutgers.edu
# Copyright (c) 2018
###########################################################################
from __future__ import division
import os
import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import interpolate
from .base import BaseNet
from ...nn import ACFModule, ConcurrentModule, SyncBatchNorm
from .fcn import FCNHead
from .encnet import EncModule
__all__ = ['ATTEN', 'get_atten']
class ATTEN(BaseNet):
def __init__(self, nclass, backbone, nheads=8, nmixs=1, with_global=True,
with_enc=True, with_lateral=False, aux=True, se_loss=False,
norm_layer=SyncBatchNorm, **kwargs):
super(ATTEN, self).__init__(nclass, backbone, aux, se_loss,
norm_layer=norm_layer, **kwargs)
in_channels = 4096 if self.backbone.startswith('wideresnet') else 2048
self.head = ATTENHead(in_channels, nclass, norm_layer, self._up_kwargs,
nheads=nheads, nmixs=nmixs, with_global=with_global,
with_enc=with_enc, se_loss=se_loss,
lateral=with_lateral)
if aux:
self.auxlayer = FCNHead(1024, nclass, norm_layer)
def forward(self, x):
imsize = x.size()[2:]
#_, _, c3, c4 = self.base_forward(x)
#x = list(self.head(c4))
features = self.base_forward(x)
x = list(self.head(*features))
x[0] = interpolate(x[0], imsize, **self._up_kwargs)
if self.aux:
#auxout = self.auxlayer(c3)
auxout = self.auxlayer(features[2])
auxout = interpolate(auxout, imsize, **self._up_kwargs)
x.append(auxout)
return tuple(x)
def demo(self, x):
imsize = x.size()[2:]
features = self.base_forward(x)
return self.head.demo(*features)
class GlobalPooling(nn.Module):
def __init__(self, in_channels, out_channels, norm_layer, up_kwargs):
super(GlobalPooling, self).__init__()
self._up_kwargs = up_kwargs
self.gap = nn.Sequential(nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, out_channels, 1, bias=False),
norm_layer(out_channels),
nn.ReLU(True))
def forward(self, x):
_, _, h, w = x.size()
pool = self.gap(x)
return interpolate(pool, (h,w), **self._up_kwargs)
class ATTENHead(nn.Module):
def __init__(self, in_channels, out_channels, norm_layer, up_kwargs,
nheads, nmixs, with_global,
with_enc, se_loss, lateral):
super(ATTENHead, self).__init__()
self.with_enc = with_enc
self.se_loss = se_loss
self._up_kwargs = up_kwargs
inter_channels = in_channels // 4
self.lateral = lateral
self.conv5 = nn.Sequential(
nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
norm_layer(inter_channels),
nn.ReLU())
if lateral:
self.connect = nn.ModuleList([
nn.Sequential(
nn.Conv2d(512, 512, kernel_size=1, bias=False),
norm_layer(512),
nn.ReLU(inplace=True)),
nn.Sequential(
nn.Conv2d(1024, 512, kernel_size=1, bias=False),
norm_layer(512),
nn.ReLU(inplace=True)),
])
self.fusion = nn.Sequential(
nn.Conv2d(3*512, 512, kernel_size=3, padding=1, bias=False),
norm_layer(512),
nn.ReLU(inplace=True))
extended_channels = 0
self.atten = ACFModule(nheads, nmixs, inter_channels, inter_channels//nheads*nmixs,
inter_channels//nheads, norm_layer)
if with_global:
extended_channels = inter_channels
self.atten_layers = ConcurrentModule([
GlobalPooling(inter_channels, extended_channels, norm_layer, self._up_kwargs),
self.atten,
#nn.Sequential(*atten),
])
else:
self.atten_layers = nn.Sequential(self.atten)
if with_enc:
self.encmodule = EncModule(inter_channels+extended_channels, out_channels, ncodes=32,
se_loss=se_loss, norm_layer=norm_layer)
self.conv6 = nn.Sequential(nn.Dropout2d(0.1, False),
nn.Conv2d(inter_channels+extended_channels, out_channels, 1))
def forward(self, *inputs):
feat = self.conv5(inputs[-1])
if self.lateral:
c2 = self.connect[0](inputs[1])
c3 = self.connect[1](inputs[2])
feat = self.fusion(torch.cat([feat, c2, c3], 1))
feat = self.atten_layers(feat)
if self.with_enc:
outs = list(self.encmodule(feat))
else:
outs = [feat]
outs[0] = self.conv6(outs[0])
return tuple(outs)
def demo(self, *inputs):
feat = self.conv5(inputs[-1])
if self.lateral:
c2 = self.connect[0](inputs[1])
c3 = self.connect[1](inputs[2])
feat = self.fusion(torch.cat([feat, c2, c3], 1))
attn = self.atten.demo(feat)
return attn
def get_atten(dataset='pascal_voc', backbone='resnet50s', pretrained=False,
root='~/.encoding/models', **kwargs):
r"""ATTEN model from the paper `"Fully Convolutional Network for semantic segmentation"
<https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_atten.pdf>`_
Parameters
----------
dataset : str, default pascal_voc
The dataset that model pretrained on. (pascal_voc, ade20k)
pretrained : bool, default False
Whether to load the pretrained weights for model.
pooling_mode : str, default 'avg'
Using 'max' pool or 'avg' pool in the Attention module.
root : str, default '~/.encoding/models'
Location for keeping the model parameters.
Examples
--------
>>> model = get_atten(dataset='pascal_voc', backbone='resnet50s', pretrained=False)
>>> print(model)
"""
# infer number of classes
from ...datasets import datasets, acronyms
model = ATTEN(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(
get_model_file('atten_%s_%s'%(backbone, acronyms[dataset]), root=root)))
return model
...@@ -14,17 +14,47 @@ from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.parallel_apply import parallel_apply
from torch.nn.parallel.scatter_gather import scatter

from ...utils import batch_pix_accuracy, batch_intersection_union
from ..backbone import *

up_kwargs = {'mode': 'bilinear', 'align_corners': True}

__all__ = ['BaseNet', 'MultiEvalModule']

def get_backbone(name, **kwargs):
    models = {
        # resnet
        'resnet50': resnet50,
        'resnet101': resnet101,
        'resnet152': resnet152,
        # resnest
        'resnest50': resnest50,
        'resnest101': resnest101,
        'resnest200': resnest200,
        'resnest269': resnest269,
        # resnet other variants
        'resnet50s': resnet50s,
        'resnet101s': resnet101s,
        'resnet152s': resnet152s,
        'resnet50d': resnet50d,
        'resnext50_32x4d': resnext50_32x4d,
        'resnext101_32x8d': resnext101_32x8d,
        # other segmentation backbones
        'xception65': xception65,
        'wideresnet38': wideresnet38,
        'wideresnet50': wideresnet50,
    }
    name = name.lower()
    if name not in models:
        raise ValueError('%s\n\t%s' % (str(name), '\n\t'.join(sorted(models.keys()))))
    net = models[name](**kwargs)
    return net

class BaseNet(nn.Module):
    def __init__(self, nclass, backbone, aux, se_loss, dilated=True, norm_layer=None,
                 base_size=520, crop_size=480, mean=[.485, .456, .406],
                 std=[.229, .224, .225], root='~/.encoding/models', *args, **kwargs):
        super(BaseNet, self).__init__()
        self.nclass = nclass
        self.aux = aux
...@@ -35,18 +65,11 @@ class BaseNet(nn.Module):
        self.crop_size = crop_size
        # copying modules from pretrained models
        self.backbone = backbone
        self.pretrained = get_backbone(backbone, pretrained=True, dilated=dilated,
                                       norm_layer=norm_layer, root=root,
                                       *args, **kwargs)
        self.pretrained.fc = None
        # bilinear upsample options
        self._up_kwargs = up_kwargs

    def base_forward(self, x):
......
@@ -14,6 +14,28 @@ from .base import BaseNet
from .fcn import FCNHead

class DeepLabV3(BaseNet):
r"""DeepLabV3
Parameters
----------
nclass : int
Number of categories for the training dataset.
backbone : string
Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
'resnet101' or 'resnet152').
norm_layer : object
Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
        for Synchronized Cross-GPU BatchNormalization).
aux : bool
Auxiliary loss.
Reference:
Chen, Liang-Chieh, et al. "Rethinking atrous convolution for semantic image segmentation."
arXiv preprint arXiv:1706.05587 (2017).
"""
    def __init__(self, nclass, backbone, aux=True, se_loss=False, norm_layer=nn.BatchNorm2d, **kwargs):
        super(DeepLabV3, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer, **kwargs)
        self.head = DeepLabV3Head(2048, nclass, norm_layer, self._up_kwargs)
@@ -22,7 +44,7 @@ class DeepLabV3(BaseNet):
    def forward(self, x):
        _, _, h, w = x.size()
-        _, _, c3, c4 = self.base_forward(x)
+        c1, c2, c3, c4 = self.base_forward(x)
        outputs = []
        x = self.head(c4)
@@ -104,7 +126,7 @@ class ASPP_Module(nn.Module):
        y = torch.cat((feat0, feat1, feat2, feat3, feat4), 1)
        return self.project(y)

-def get_deeplab(dataset='pascal_voc', backbone='resnet50', pretrained=False,
+def get_deeplab(dataset='pascal_voc', backbone='resnet50s', pretrained=False,
                root='~/.encoding/models', **kwargs):
    acronyms = {
        'pascal_voc': 'voc',
@@ -112,10 +134,10 @@ def get_deeplab(dataset='pascal_voc', backbone='resnet50', pretrained=False,
        'ade20k': 'ade',
    }
    # infer number of classes
-    from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
+    from ...datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
    model = DeepLabV3(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
    if pretrained:
-        from .model_store import get_model_file
+        from ..model_store import get_model_file
        model.load_state_dict(torch.load(
            get_model_file('deeplab_%s_%s'%(backbone, acronyms[dataset]), root=root)))
    return model
@@ -137,4 +159,42 @@ def get_deeplab_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
    >>> model = get_deeplab_resnet50_ade(pretrained=True)
    >>> print(model)
    """
-    return get_deeplab('ade20k', 'resnet50', pretrained, root=root, **kwargs)
+    return get_deeplab('ade20k', 'resnet50s', pretrained, root=root, **kwargs)
def get_deeplab_resnest50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
r"""DeepLabV3 model from the paper `"Context Encoding for Semantic Segmentation"
<https://arxiv.org/pdf/1803.08904.pdf>`_
Parameters
----------
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.encoding/models'
Location for keeping the model parameters.
Examples
--------
    >>> model = get_deeplab_resnest50_ade(pretrained=True)
>>> print(model)
"""
return get_deeplab('ade20k', 'resnest50', pretrained, root=root, **kwargs)
def get_deeplab_resnest101_ade(pretrained=False, root='~/.encoding/models', **kwargs):
r"""DeepLabV3 model from the paper `"Context Encoding for Semantic Segmentation"
<https://arxiv.org/pdf/1803.08904.pdf>`_
Parameters
----------
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.encoding/models'
Location for keeping the model parameters.
Examples
--------
    >>> model = get_deeplab_resnest101_ade(pretrained=True)
>>> print(model)
"""
return get_deeplab('ade20k', 'resnest101', pretrained, root=root, **kwargs)
@@ -11,7 +11,7 @@ import torch.nn.functional as F
from .base import BaseNet
from .fcn import FCNHead
-from ..nn import SyncBatchNorm, Encoding, Mean
+from ...nn import SyncBatchNorm, Encoding, Mean

__all__ = ['EncNet', 'EncModule', 'get_encnet', 'get_encnet_resnet50_pcontext',
           'get_encnet_resnet101_pcontext', 'get_encnet_resnet50_ade',
@@ -112,7 +112,7 @@ class EncHead(nn.Module):
        return tuple(outs)

-def get_encnet(dataset='pascal_voc', backbone='resnet50', pretrained=False,
+def get_encnet(dataset='pascal_voc', backbone='resnet50s', pretrained=False,
               root='~/.encoding/models', **kwargs):
    r"""EncNet model from the paper `"Context Encoding for Semantic Segmentation"
    <https://arxiv.org/pdf/1803.08904.pdf>`_
@@ -121,8 +121,8 @@ def get_encnet(dataset='pascal_voc', backbone='resnet50', pretrained=False,
    ----------
    dataset : str, default pascal_voc
        The dataset that model pretrained on. (pascal_voc, ade20k)
-    backbone : str, default resnet50
-        The backbone network. (resnet50, 101, 152)
+    backbone : str, default resnet50s
+        The backbone network. (resnet50s, 101s, 152s)
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.encoding/models'
@@ -131,12 +131,12 @@ def get_encnet(dataset='pascal_voc', backbone='resnet50', pretrained=False,
    Examples
    --------
-    >>> model = get_encnet(dataset='pascal_voc', backbone='resnet50', pretrained=False)
+    >>> model = get_encnet(dataset='pascal_voc', backbone='resnet50s', pretrained=False)
    >>> print(model)
    """
    kwargs['lateral'] = True if dataset.lower().startswith('p') else False
    # infer number of classes
-    from ..datasets import datasets, acronyms
+    from ...datasets import datasets, acronyms
    model = EncNet(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
    if pretrained:
        from .model_store import get_model_file
@@ -161,7 +161,7 @@ def get_encnet_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **kwargs):
    >>> model = get_encnet_resnet50_pcontext(pretrained=True)
    >>> print(model)
    """
-    return get_encnet('pcontext', 'resnet50', pretrained, root=root, aux=True,
+    return get_encnet('pcontext', 'resnet50s', pretrained, root=root, aux=True,
                      base_size=520, crop_size=480, **kwargs)

def get_encnet_resnet101_pcontext(pretrained=False, root='~/.encoding/models', **kwargs):
@@ -181,7 +181,7 @@ def get_encnet_resnet101_pcontext(pretrained=False, root='~/.encoding/models', **kwargs):
    >>> model = get_encnet_resnet101_pcontext(pretrained=True)
    >>> print(model)
    """
-    return get_encnet('pcontext', 'resnet101', pretrained, root=root, aux=True,
+    return get_encnet('pcontext', 'resnet101s', pretrained, root=root, aux=True,
                      base_size=520, crop_size=480, **kwargs)

def get_encnet_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
@@ -221,7 +221,7 @@ def get_encnet_resnet101_ade(pretrained=False, root='~/.encoding/models', **kwargs):
    >>> model = get_encnet_resnet50_ade(pretrained=True)
    >>> print(model)
    """
-    return get_encnet('ade20k', 'resnet101', pretrained, root=root, aux=True,
+    return get_encnet('ade20k', 'resnet101s', pretrained, root=root, aux=True,
                      base_size=640, crop_size=576, **kwargs)

def get_encnet_resnet152_ade(pretrained=False, root='~/.encoding/models', **kwargs):
@@ -241,5 +241,5 @@ def get_encnet_resnet152_ade(pretrained=False, root='~/.encoding/models', **kwargs):
    >>> model = get_encnet_resnet50_ade(pretrained=True)
    >>> print(model)
    """
-    return get_encnet('ade20k', 'resnet152', pretrained, root=root, aux=True,
+    return get_encnet('ade20k', 'resnet152s', pretrained, root=root, aux=True,
                      base_size=520, crop_size=480, **kwargs)
###########################################################################
# Created by: Hang Zhang
# Email: zhang.hang@rutgers.edu
# Copyright (c) 2017
###########################################################################
from __future__ import division
import os
import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import upsample
from .base import BaseNet
torch_ver = torch.__version__[:3]
__all__ = ['FCFPN', 'get_fcfpn', 'get_fcfpn_50_ade']
class FCFPN(BaseNet):
r"""Fully Convolutional Networks for Semantic Segmentation
Parameters
----------
nclass : int
Number of categories for the training dataset.
backbone : string
Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
'resnet101' or 'resnet152').
norm_layer : object
Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
Reference:
Long, Jonathan, Evan Shelhamer, and Trevor Darrell. "Fully convolutional networks
for semantic segmentation." *CVPR*, 2015
Examples
--------
>>> model = FCFPN(nclass=21, backbone='resnet50')
>>> print(model)
"""
    def __init__(self, nclass, backbone, aux=False, se_loss=False, norm_layer=nn.BatchNorm2d, **kwargs):
super(FCFPN, self).__init__(nclass, backbone, aux, se_loss, dilated=False, norm_layer=norm_layer)
self.head = FCFPNHead(nclass, norm_layer, up_kwargs=self._up_kwargs)
assert not aux, "FCFPN does not support aux loss"
def forward(self, x):
imsize = x.size()[2:]
features = self.base_forward(x)
x = list(self.head(*features))
x[0] = upsample(x[0], imsize, **self._up_kwargs)
return tuple(x)
class FCFPNHead(nn.Module):
def __init__(self, out_channels, norm_layer=None, fpn_inchannels=[256, 512, 1024, 2048],
fpn_dim=256, up_kwargs=None):
super(FCFPNHead, self).__init__()
# bilinear upsample options
assert up_kwargs is not None
self._up_kwargs = up_kwargs
fpn_lateral = []
for fpn_inchannel in fpn_inchannels[:-1]:
fpn_lateral.append(nn.Sequential(
nn.Conv2d(fpn_inchannel, fpn_dim, kernel_size=1, bias=False),
norm_layer(fpn_dim),
nn.ReLU(inplace=True),
))
self.fpn_lateral = nn.ModuleList(fpn_lateral)
fpn_out = []
for _ in range(len(fpn_inchannels) - 1):
fpn_out.append(nn.Sequential(
nn.Conv2d(fpn_dim, fpn_dim, kernel_size=3, padding=1, bias=False),
norm_layer(fpn_dim),
nn.ReLU(inplace=True),
))
self.fpn_out = nn.ModuleList(fpn_out)
self.c4conv = nn.Sequential(nn.Conv2d(fpn_inchannels[-1], fpn_dim, 3, padding=1, bias=False),
norm_layer(fpn_dim),
nn.ReLU())
inter_channels = len(fpn_inchannels) * fpn_dim
self.conv5 = nn.Sequential(nn.Conv2d(inter_channels, 512, 3, padding=1, bias=False),
norm_layer(512),
nn.ReLU(),
nn.Dropout2d(0.1, False),
nn.Conv2d(512, out_channels, 1))
def forward(self, *inputs):
c4 = inputs[-1]
#se_pred = False
if hasattr(self, 'extramodule'):
#if self.extramodule.se_loss:
# se_pred = True
# feat, se_out = self.extramodule(feat)
#else:
c4 = self.extramodule(c4)
feat = self.c4conv(c4)
c1_size = inputs[0].size()[2:]
feat_up = upsample(feat, c1_size, **self._up_kwargs)
fpn_features = [feat_up]
# c4, c3, c2, c1
for i in reversed(range(len(inputs) - 1)):
feat_i = self.fpn_lateral[i](inputs[i])
feat = upsample(feat, feat_i.size()[2:], **self._up_kwargs)
feat = feat + feat_i
# upsample to the same size with c1
feat_up = upsample(self.fpn_out[i](feat), c1_size, **self._up_kwargs)
fpn_features.append(feat_up)
fpn_features = torch.cat(fpn_features, 1)
#if se_pred:
# return (self.conv5(fpn_features), se_out)
return (self.conv5(fpn_features), )
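The head follows a standard FPN top-down pass: `c4conv` reduces the deepest map, which is repeatedly upsampled and summed with the 1x1 lateral projections, and every level is finally upsampled to the c1 resolution and concatenated before `conv5`. A quick shape check under assumed ResNet-style feature sizes (strides 4/8/16/32, `nn.BatchNorm2d` as the norm layer), using the `FCFPNHead` defined above:

# Hypothetical shape check for FCFPNHead; feature sizes assume a 512x512 input
# passed through a non-dilated ResNet50-style backbone (strides 4, 8, 16, 32).
import torch
import torch.nn as nn

head = FCFPNHead(out_channels=150, norm_layer=nn.BatchNorm2d,
                 up_kwargs={'mode': 'bilinear', 'align_corners': True})
head.eval()  # eval mode keeps BatchNorm/Dropout deterministic for this single-sample check
c1 = torch.randn(1, 256, 128, 128)
c2 = torch.randn(1, 512, 64, 64)
c3 = torch.randn(1, 1024, 32, 32)
c4 = torch.randn(1, 2048, 16, 16)
with torch.no_grad():
    out = head(c1, c2, c3, c4)
print(out[0].shape)  # torch.Size([1, 150, 128, 128]) -- 1/4 of the input resolution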
def get_fcfpn(dataset='pascal_voc', backbone='resnet50', pretrained=False,
root='~/.encoding/models', **kwargs):
r"""FCFPN model from the paper `"Fully Convolutional Network for semantic segmentation"
<https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcfpn.pdf>`_
Parameters
----------
dataset : str, default pascal_voc
The dataset that model pretrained on. (pascal_voc, ade20k)
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.encoding/models'
Location for keeping the model parameters.
Examples
--------
>>> model = get_fcfpn(dataset='pascal_voc', backbone='resnet50s', pretrained=False)
>>> print(model)
"""
acronyms = {
'pascal_voc': 'voc',
'pascal_aug': 'voc',
'ade20k': 'ade',
}
# infer number of classes
from ...datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
model = FCFPN(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(
get_model_file('fcfpn_%s_%s'%(backbone, acronyms[dataset]), root=root)))
return model
def get_fcfpn_50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
<https://arxiv.org/pdf/1803.08904.pdf>`_
Parameters
----------
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.encoding/models'
Location for keeping the model parameters.
Examples
--------
>>> model = get_fcfpn_50_ade(pretrained=True)
>>> print(model)
"""
    return get_fcfpn('ade20k', 'resnet50s', pretrained, root=root, **kwargs)
@@ -9,7 +9,7 @@ import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import interpolate
-from ..nn import ConcurrentModule, SyncBatchNorm
+from ...nn import ConcurrentModule, SyncBatchNorm

from .base import BaseNet
@@ -23,8 +23,8 @@ class FCN(BaseNet):
    nclass : int
        Number of categories for the training dataset.
    backbone : string
-        Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
-        'resnet101' or 'resnet152').
+        Pre-trained dilated backbone network type (default:'resnet50s'; 'resnet50s',
+        'resnet101s' or 'resnet152s').
    norm_layer : object
        Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
@@ -36,12 +36,13 @@ class FCN(BaseNet):
    Examples
    --------
-    >>> model = FCN(nclass=21, backbone='resnet50')
+    >>> model = FCN(nclass=21, backbone='resnet50s')
    >>> print(model)
    """
    def __init__(self, nclass, backbone, aux=True, se_loss=False, with_global=False,
-                 norm_layer=SyncBatchNorm, **kwargs):
-        super(FCN, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer, **kwargs)
+                 norm_layer=SyncBatchNorm, *args, **kwargs):
+        super(FCN, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer,
+                                  *args, **kwargs)
        self.head = FCNHead(2048, nclass, norm_layer, self._up_kwargs, with_global)
        if aux:
            self.auxlayer = FCNHead(1024, nclass, norm_layer)
@@ -109,7 +110,7 @@ class FCNHead(nn.Module):
        return self.conv5(x)

-def get_fcn(dataset='pascal_voc', backbone='resnet50', pretrained=False,
+def get_fcn(dataset='pascal_voc', backbone='resnet50s', pretrained=False,
            root='~/.encoding/models', **kwargs):
    r"""FCN model from the paper `"Fully Convolutional Network for semantic segmentation"
    <https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcn.pdf>`_
@@ -123,11 +124,11 @@ def get_fcn(dataset='pascal_voc', backbone='resnet50', pretrained=False,
        Location for keeping the model parameters.
    Examples
    --------
-    >>> model = get_fcn(dataset='pascal_voc', backbone='resnet50', pretrained=False)
+    >>> model = get_fcn(dataset='pascal_voc', backbone='resnet50s', pretrained=False)
    >>> print(model)
    """
    # infer number of classes
-    from ..datasets import datasets, acronyms
+    from ...datasets import datasets, acronyms
    model = FCN(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
    if pretrained:
        from .model_store import get_model_file
@@ -152,7 +153,7 @@ def get_fcn_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **kwargs):
    >>> model = get_fcn_resnet50_pcontext(pretrained=True)
    >>> print(model)
    """
-    return get_fcn('pcontext', 'resnet50', pretrained, root=root, aux=False, **kwargs)
+    return get_fcn('pcontext', 'resnet50s', pretrained, root=root, aux=False, **kwargs)

def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
    r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
@@ -171,4 +172,4 @@ def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
    >>> model = get_fcn_resnet50_ade(pretrained=True)
    >>> print(model)
    """
-    return get_fcn('ade20k', 'resnet50', pretrained, root=root, **kwargs)
+    return get_fcn('ade20k', 'resnet50s', pretrained, root=root, **kwargs)
@@ -12,7 +12,7 @@ from torch.nn.functional import interpolate
from .base import BaseNet
from .fcn import FCNHead
-from ..nn import PyramidPooling
+from ...nn import PyramidPooling

class PSP(BaseNet):
    def __init__(self, nclass, backbone, aux=True, se_loss=False, norm_layer=nn.BatchNorm2d, **kwargs):
@@ -50,10 +50,10 @@ class PSPHead(nn.Module):
    def forward(self, x):
        return self.conv5(x)

-def get_psp(dataset='pascal_voc', backbone='resnet50', pretrained=False,
+def get_psp(dataset='pascal_voc', backbone='resnet50s', pretrained=False,
            root='~/.encoding/models', **kwargs):
    # infer number of classes
-    from ..datasets import datasets, acronyms
+    from ...datasets import datasets, acronyms
    model = PSP(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
    if pretrained:
        from .model_store import get_model_file
@@ -78,4 +78,4 @@ def get_psp_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
    >>> model = get_psp_resnet50_ade(pretrained=True)
    >>> print(model)
    """
-    return get_psp('ade20k', 'resnet50', pretrained, root=root, **kwargs)
+    return get_psp('ade20k', 'resnet50s', pretrained, root=root, **kwargs)
###########################################################################
# Created by: Hang Zhang
# Email: zhang.hang@rutgers.edu
# Copyright (c) 2017
###########################################################################
from __future__ import division
import os
import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import upsample
from .base import BaseNet
from .fcfpn import FCFPNHead
from ...nn import PyramidPooling
torch_ver = torch.__version__[:3]
__all__ = ['UperNet', 'get_upernet', 'get_upernet_50_ade']
class UperNet(BaseNet):
r"""Fully Convolutional Networks for Semantic Segmentation
Parameters
----------
nclass : int
Number of categories for the training dataset.
backbone : string
Pre-trained dilated backbone network type (default:'resnet50s'; 'resnet50s',
'resnet101s' or 'resnet152s').
norm_layer : object
Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
Reference:
Long, Jonathan, Evan Shelhamer, and Trevor Darrell. "Fully convolutional networks
for semantic segmentation." *CVPR*, 2015
Examples
--------
>>> model = UperNet(nclass=21, backbone='resnet50s')
>>> print(model)
"""
    def __init__(self, nclass, backbone, aux=False, se_loss=False, norm_layer=nn.BatchNorm2d, **kwargs):
super(UperNet, self).__init__(nclass, backbone, aux, se_loss, dilated=False, norm_layer=norm_layer)
self.head = UperNetHead(nclass, norm_layer, up_kwargs=self._up_kwargs)
assert not aux, "UperNet does not support aux loss"
def forward(self, x):
imsize = x.size()[2:]
features = self.base_forward(x)
x = list(self.head(*features))
x[0] = upsample(x[0], imsize, **self._up_kwargs)
return tuple(x)
class UperNetHead(FCFPNHead):
    def __init__(self, out_channels, norm_layer=None, fpn_inchannels=[256, 512, 1024, 2048],
                 fpn_dim=256, up_kwargs=None):
        # copy before doubling the last entry so the shared default list is not mutated in place
        fpn_inchannels = list(fpn_inchannels)
        fpn_inchannels[-1] = fpn_inchannels[-1] * 2
        super(UperNetHead, self).__init__(out_channels, norm_layer, fpn_inchannels,
                                          fpn_dim, up_kwargs)
        self.extramodule = PyramidPooling(fpn_inchannels[-1] // 2, norm_layer, up_kwargs)
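`UperNetHead` only adds a `PyramidPooling` `extramodule` in front of the `FCFPNHead` pipeline; the pooled branches are concatenated onto the c4 input, so the module doubles the c4 channel count as used here, which is why `fpn_inchannels[-1]` is doubled before the parent constructor runs and why `PyramidPooling` is built with half of that doubled value. A hypothetical smoke test, analogous to the `FCFPNHead` check above:

# Hypothetical smoke test for UperNetHead, mirroring the FCFPNHead shape check.
import torch
import torch.nn as nn

head = UperNetHead(out_channels=150, norm_layer=nn.BatchNorm2d,
                   up_kwargs={'mode': 'bilinear', 'align_corners': True})
head.eval()  # the 1x1 pooled branch cannot run BatchNorm in training mode on a single sample
feats = [torch.randn(1, c, s, s) for c, s in
         ((256, 128), (512, 64), (1024, 32), (2048, 16))]
with torch.no_grad():
    out = head(*feats)
print(out[0].shape)  # torch.Size([1, 150, 128, 128])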
def get_upernet(dataset='pascal_voc', backbone='resnet50s', pretrained=False,
root='~/.encoding/models', **kwargs):
r"""UperNet model from the paper `"Fully Convolutional Network for semantic segmentation"
<https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_upernet.pdf>`_
Parameters
----------
dataset : str, default pascal_voc
The dataset that model pretrained on. (pascal_voc, ade20k)
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.encoding/models'
Location for keeping the model parameters.
Examples
--------
>>> model = get_upernet(dataset='pascal_voc', backbone='resnet50s', pretrained=False)
>>> print(model)
"""
acronyms = {
'pascal_voc': 'voc',
'pascal_aug': 'voc',
'ade20k': 'ade',
}
# infer number of classes
from ...datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
model = UperNet(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(
get_model_file('upernet_%s_%s'%(backbone, acronyms[dataset]), root=root)))
return model
def get_upernet_50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
<https://arxiv.org/pdf/1803.08904.pdf>`_
Parameters
----------
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.encoding/models'
Location for keeping the model parameters.
Examples
--------
>>> model = get_upernet_50_ade(pretrained=True)
>>> print(model)
"""
    return get_upernet('ade20k', 'resnet50s', pretrained, root=root, **kwargs)