"python/vscode:/vscode.git/clone" did not exist on "616b59f384ad13b824fa8bb634444b43967f8c8a"
Commit 0e747be8 authored by Kai Chen's avatar Kai Chen
Browse files

update resnet backbone

parent e8397e43
...@@ -3,7 +3,7 @@ model = dict( ...@@ -3,7 +3,7 @@ model = dict(
type='FasterRCNN', type='FasterRCNN',
pretrained='modelzoo://resnet50', pretrained='modelzoo://resnet50',
backbone=dict( backbone=dict(
type='resnet', type='ResNet',
depth=50, depth=50,
num_stages=4, num_stages=4,
out_indices=(0, 1, 2, 3), out_indices=(0, 1, 2, 3),
......
...@@ -3,7 +3,7 @@ model = dict( ...@@ -3,7 +3,7 @@ model = dict(
type='MaskRCNN', type='MaskRCNN',
pretrained='modelzoo://resnet50', pretrained='modelzoo://resnet50',
backbone=dict( backbone=dict(
type='resnet', type='ResNet',
depth=50, depth=50,
num_stages=4, num_stages=4,
out_indices=(0, 1, 2, 3), out_indices=(0, 1, 2, 3),
......
...@@ -3,7 +3,7 @@ model = dict( ...@@ -3,7 +3,7 @@ model = dict(
type='RPN', type='RPN',
pretrained='modelzoo://resnet50', pretrained='modelzoo://resnet50',
backbone=dict( backbone=dict(
type='resnet', type='ResNet',
depth=50, depth=50,
num_stages=4, num_stages=4,
out_indices=(0, 1, 2, 3), out_indices=(0, 1, 2, 3),
......
from .resnet import resnet from .resnet import ResNet
__all__ = ['resnet'] __all__ = ['ResNet']
import logging import logging
import math
import torch.nn as nn import torch.nn as nn
import torch.utils.checkpoint as cp import torch.utils.checkpoint as cp
from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint from mmcv.runner import load_checkpoint
...@@ -27,7 +28,8 @@ class BasicBlock(nn.Module): ...@@ -27,7 +28,8 @@ class BasicBlock(nn.Module):
stride=1, stride=1,
dilation=1, dilation=1,
downsample=None, downsample=None,
style='pytorch'): style='pytorch',
with_cp=False):
super(BasicBlock, self).__init__() super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride, dilation) self.conv1 = conv3x3(inplanes, planes, stride, dilation)
self.bn1 = nn.BatchNorm2d(planes) self.bn1 = nn.BatchNorm2d(planes)
...@@ -37,6 +39,7 @@ class BasicBlock(nn.Module): ...@@ -37,6 +39,7 @@ class BasicBlock(nn.Module):
self.downsample = downsample self.downsample = downsample
self.stride = stride self.stride = stride
self.dilation = dilation self.dilation = dilation
assert not with_cp
def forward(self, x): def forward(self, x):
residual = x residual = x
...@@ -69,7 +72,6 @@ class Bottleneck(nn.Module): ...@@ -69,7 +72,6 @@ class Bottleneck(nn.Module):
style='pytorch', style='pytorch',
with_cp=False): with_cp=False):
"""Bottleneck block. """Bottleneck block.
If style is "pytorch", the stride-two layer is the 3x3 conv layer, If style is "pytorch", the stride-two layer is the 3x3 conv layer,
if it is "caffe", the stride-two layer is the first 1x1 conv layer. if it is "caffe", the stride-two layer is the first 1x1 conv layer.
""" """
...@@ -174,64 +176,73 @@ def make_res_layer(block, ...@@ -174,64 +176,73 @@ def make_res_layer(block,
return nn.Sequential(*layers) return nn.Sequential(*layers)
class ResHead(nn.Module): class ResNet(nn.Module):
"""ResNet backbone.
def __init__(self,
block,
num_blocks,
stride=2,
dilation=1,
style='pytorch'):
self.layer4 = make_res_layer(
block,
1024,
512,
num_blocks,
stride=stride,
dilation=dilation,
style=style)
def forward(self, x):
return self.layer4(x)
Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
num_stages (int): Resnet stages, normally 4.
strides (Sequence[int]): Strides of the first block of each stage.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
bn_eval (bool): Whether to set BN layers to eval mode, namely, freeze
running stats (mean and var).
bn_frozen (bool): Whether to freeze weight and bias of BN layers.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
"""
class ResNet(nn.Module): arch_settings = {
18: (BasicBlock, (2, 2, 2, 2)),
34: (BasicBlock, (3, 4, 6, 3)),
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def __init__(self, def __init__(self,
block, depth,
layers, num_stages=4,
strides=(1, 2, 2, 2), strides=(1, 2, 2, 2),
dilations=(1, 1, 1, 1), dilations=(1, 1, 1, 1),
out_indices=(0, 1, 2, 3), out_indices=(0, 1, 2, 3),
frozen_stages=-1,
style='pytorch', style='pytorch',
sync_bn=False, frozen_stages=-1,
with_cp=False, bn_eval=True,
strict_frozen=False): bn_frozen=False,
with_cp=False):
super(ResNet, self).__init__() super(ResNet, self).__init__()
if not len(layers) == len(strides) == len(dilations): if depth not in self.arch_settings:
raise ValueError( raise KeyError('invalid depth {} for resnet'.format(depth))
'The number of layers, strides and dilations must be equal, ' assert num_stages >= 1 and num_stages <= 4
'but found have {} layers, {} strides and {} dilations'.format( block, stage_blocks = self.arch_settings[depth]
len(layers), len(strides), len(dilations))) stage_blocks = stage_blocks[:num_stages]
assert max(out_indices) < len(layers) assert len(strides) == len(dilations) == num_stages
assert max(out_indices) < num_stages
self.out_indices = out_indices self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.style = style self.style = style
self.sync_bn = sync_bn self.frozen_stages = frozen_stages
self.bn_eval = bn_eval
self.bn_frozen = bn_frozen
self.with_cp = with_cp
self.inplanes = 64 self.inplanes = 64
self.conv1 = nn.Conv2d( self.conv1 = nn.Conv2d(
3, 64, kernel_size=7, stride=2, padding=3, bias=False) 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64) self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.res_layers = []
for i, num_blocks in enumerate(layers):
self.res_layers = []
for i, num_blocks in enumerate(stage_blocks):
stride = strides[i] stride = strides[i]
dilation = dilations[i] dilation = dilations[i]
layer_name = 'layer{}'.format(i + 1)
planes = 64 * 2**i planes = 64 * 2**i
res_layer = make_res_layer( res_layer = make_res_layer(
block, block,
...@@ -243,12 +254,11 @@ class ResNet(nn.Module): ...@@ -243,12 +254,11 @@ class ResNet(nn.Module):
style=self.style, style=self.style,
with_cp=with_cp) with_cp=with_cp)
self.inplanes = planes * block.expansion self.inplanes = planes * block.expansion
layer_name = 'layer{}'.format(i + 1)
self.add_module(layer_name, res_layer) self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name) self.res_layers.append(layer_name)
self.feat_dim = block.expansion * 64 * 2**(len(layers) - 1)
self.with_cp = with_cp
self.strict_frozen = strict_frozen self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)
def init_weights(self, pretrained=None): def init_weights(self, pretrained=None):
if isinstance(pretrained, str): if isinstance(pretrained, str):
...@@ -257,11 +267,9 @@ class ResNet(nn.Module): ...@@ -257,11 +267,9 @@ class ResNet(nn.Module):
elif pretrained is None: elif pretrained is None:
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels kaiming_init(m)
nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1) constant_init(m, 1)
nn.init.constant_(m.bias, 0)
else: else:
raise TypeError('pretrained must be a str or None') raise TypeError('pretrained must be a str or None')
...@@ -283,11 +291,11 @@ class ResNet(nn.Module): ...@@ -283,11 +291,11 @@ class ResNet(nn.Module):
def train(self, mode=True): def train(self, mode=True):
super(ResNet, self).train(mode) super(ResNet, self).train(mode)
if not self.sync_bn: if self.bn_eval:
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.BatchNorm2d): if isinstance(m, nn.BatchNorm2d):
m.eval() m.eval()
if self.strict_frozen: if self.bn_frozen:
for params in m.parameters(): for params in m.parameters():
params.requires_grad = False params.requires_grad = False
if mode and self.frozen_stages >= 0: if mode and self.frozen_stages >= 0:
...@@ -303,39 +311,3 @@ class ResNet(nn.Module): ...@@ -303,39 +311,3 @@ class ResNet(nn.Module):
mod.eval() mod.eval()
for param in mod.parameters(): for param in mod.parameters():
param.requires_grad = False param.requires_grad = False
resnet_cfg = {
18: (BasicBlock, (2, 2, 2, 2)),
34: (BasicBlock, (3, 4, 6, 3)),
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def resnet(depth,
num_stages=4,
strides=(1, 2, 2, 2),
dilations=(1, 1, 1, 1),
out_indices=(2, ),
frozen_stages=-1,
style='pytorch',
sync_bn=False,
with_cp=False,
strict_frozen=False):
"""Constructs a ResNet model.
Args:
depth (int): depth of resnet, from {18, 34, 50, 101, 152}
num_stages (int): num of resnet stages, normally 4
strides (list): strides of the first block of each stage
dilations (list): dilation of each stage
out_indices (list): output from which stages
"""
if depth not in resnet_cfg:
raise KeyError('invalid depth {} for resnet'.format(depth))
block, layers = resnet_cfg[depth]
model = ResNet(block, layers[:num_stages], strides, dilations, out_indices,
frozen_stages, style, sync_bn, with_cp, strict_frozen)
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment