Commit 5055cdf2 authored by Kai Chen's avatar Kai Chen
Browse files

rename resnet style from fb/msra to pytorch/caffe

parent 8262d461
...@@ -27,7 +27,7 @@ class BasicBlock(nn.Module): ...@@ -27,7 +27,7 @@ class BasicBlock(nn.Module):
stride=1, stride=1,
dilation=1, dilation=1,
downsample=None, downsample=None,
style='fb'): style='pytorch'):
super(BasicBlock, self).__init__() super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride, dilation) self.conv1 = conv3x3(inplanes, planes, stride, dilation)
self.bn1 = nn.BatchNorm2d(planes) self.bn1 = nn.BatchNorm2d(planes)
...@@ -66,15 +66,16 @@ class Bottleneck(nn.Module): ...@@ -66,15 +66,16 @@ class Bottleneck(nn.Module):
stride=1, stride=1,
dilation=1, dilation=1,
downsample=None, downsample=None,
style='fb', style='pytorch',
with_cp=False): with_cp=False):
"""Bottleneck block """Bottleneck block.
if style is "fb", the stride-two layer is the 3x3 conv layer,
if style is "msra", the stride-two layer is the first 1x1 conv layer If style is "pytorch", the stride-two layer is the 3x3 conv layer,
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
""" """
super(Bottleneck, self).__init__() super(Bottleneck, self).__init__()
assert style in ['fb', 'msra'] assert style in ['pytorch', 'caffe']
if style == 'fb': if style == 'pytorch':
conv1_stride = 1 conv1_stride = 1
conv2_stride = stride conv2_stride = stride
else: else:
...@@ -141,7 +142,7 @@ def make_res_layer(block, ...@@ -141,7 +142,7 @@ def make_res_layer(block,
blocks, blocks,
stride=1, stride=1,
dilation=1, dilation=1,
style='fb', style='pytorch',
with_cp=False): with_cp=False):
downsample = None downsample = None
if stride != 1 or inplanes != planes * block.expansion: if stride != 1 or inplanes != planes * block.expansion:
...@@ -175,7 +176,12 @@ def make_res_layer(block, ...@@ -175,7 +176,12 @@ def make_res_layer(block,
class ResHead(nn.Module): class ResHead(nn.Module):
def __init__(self, block, num_blocks, stride=2, dilation=1, style='fb'): def __init__(self,
block,
num_blocks,
stride=2,
dilation=1,
style='pytorch'):
self.layer4 = make_res_layer( self.layer4 = make_res_layer(
block, block,
1024, 1024,
...@@ -198,7 +204,7 @@ class ResNet(nn.Module): ...@@ -198,7 +204,7 @@ class ResNet(nn.Module):
dilations=(1, 1, 1, 1), dilations=(1, 1, 1, 1),
out_indices=(0, 1, 2, 3), out_indices=(0, 1, 2, 3),
frozen_stages=-1, frozen_stages=-1,
style='fb', style='pytorch',
sync_bn=False, sync_bn=False,
with_cp=False, with_cp=False,
strict_frozen=False): strict_frozen=False):
...@@ -237,7 +243,7 @@ class ResNet(nn.Module): ...@@ -237,7 +243,7 @@ class ResNet(nn.Module):
style=self.style, style=self.style,
with_cp=with_cp) with_cp=with_cp)
self.inplanes = planes * block.expansion self.inplanes = planes * block.expansion
setattr(self, layer_name, res_layer) self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name) self.res_layers.append(layer_name)
self.feat_dim = block.expansion * 64 * 2**(len(layers) - 1) self.feat_dim = block.expansion * 64 * 2**(len(layers) - 1)
self.with_cp = with_cp self.with_cp = with_cp
...@@ -314,7 +320,7 @@ def resnet(depth, ...@@ -314,7 +320,7 @@ def resnet(depth,
dilations=(1, 1, 1, 1), dilations=(1, 1, 1, 1),
out_indices=(2, ), out_indices=(2, ),
frozen_stages=-1, frozen_stages=-1,
style='fb', style='pytorch',
sync_bn=False, sync_bn=False,
with_cp=False, with_cp=False,
strict_frozen=False): strict_frozen=False):
......
...@@ -8,7 +8,7 @@ model = dict( ...@@ -8,7 +8,7 @@ model = dict(
num_stages=4, num_stages=4,
out_indices=(0, 1, 2, 3), out_indices=(0, 1, 2, 3),
frozen_stages=1, frozen_stages=1,
style='fb'), style='pytorch'),
neck=dict( neck=dict(
type='FPN', type='FPN',
in_channels=[256, 512, 1024, 2048], in_channels=[256, 512, 1024, 2048],
......
...@@ -8,7 +8,7 @@ model = dict( ...@@ -8,7 +8,7 @@ model = dict(
num_stages=4, num_stages=4,
out_indices=(0, 1, 2, 3), out_indices=(0, 1, 2, 3),
frozen_stages=1, frozen_stages=1,
style='fb'), style='pytorch'),
neck=dict( neck=dict(
type='FPN', type='FPN',
in_channels=[256, 512, 1024, 2048], in_channels=[256, 512, 1024, 2048],
......
...@@ -8,7 +8,7 @@ model = dict( ...@@ -8,7 +8,7 @@ model = dict(
num_stages=4, num_stages=4,
out_indices=(0, 1, 2, 3), out_indices=(0, 1, 2, 3),
frozen_stages=1, frozen_stages=1,
style='fb'), style='pytorch'),
neck=dict( neck=dict(
type='FPN', type='FPN',
in_channels=[256, 512, 1024, 2048], in_channels=[256, 512, 1024, 2048],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment