Commit 5055cdf2 authored by Kai Chen's avatar Kai Chen
Browse files

rename resnet style from fb/msra to pytorch/caffe

parent 8262d461
......@@ -27,7 +27,7 @@ class BasicBlock(nn.Module):
stride=1,
dilation=1,
downsample=None,
style='fb'):
style='pytorch'):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride, dilation)
self.bn1 = nn.BatchNorm2d(planes)
......@@ -66,15 +66,16 @@ class Bottleneck(nn.Module):
stride=1,
dilation=1,
downsample=None,
style='fb',
style='pytorch',
with_cp=False):
"""Bottleneck block
if style is "fb", the stride-two layer is the 3x3 conv layer,
if style is "msra", the stride-two layer is the first 1x1 conv layer
"""Bottleneck block.
If style is "pytorch", the stride-two layer is the 3x3 conv layer,
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
super(Bottleneck, self).__init__()
assert style in ['fb', 'msra']
if style == 'fb':
assert style in ['pytorch', 'caffe']
if style == 'pytorch':
conv1_stride = 1
conv2_stride = stride
else:
......@@ -141,7 +142,7 @@ def make_res_layer(block,
blocks,
stride=1,
dilation=1,
style='fb',
style='pytorch',
with_cp=False):
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
......@@ -175,7 +176,12 @@ def make_res_layer(block,
class ResHead(nn.Module):
def __init__(self, block, num_blocks, stride=2, dilation=1, style='fb'):
def __init__(self,
block,
num_blocks,
stride=2,
dilation=1,
style='pytorch'):
self.layer4 = make_res_layer(
block,
1024,
......@@ -198,7 +204,7 @@ class ResNet(nn.Module):
dilations=(1, 1, 1, 1),
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
style='fb',
style='pytorch',
sync_bn=False,
with_cp=False,
strict_frozen=False):
......@@ -237,7 +243,7 @@ class ResNet(nn.Module):
style=self.style,
with_cp=with_cp)
self.inplanes = planes * block.expansion
setattr(self, layer_name, res_layer)
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)
self.feat_dim = block.expansion * 64 * 2**(len(layers) - 1)
self.with_cp = with_cp
......@@ -314,7 +320,7 @@ def resnet(depth,
dilations=(1, 1, 1, 1),
out_indices=(2, ),
frozen_stages=-1,
style='fb',
style='pytorch',
sync_bn=False,
with_cp=False,
strict_frozen=False):
......
......@@ -8,7 +8,7 @@ model = dict(
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='fb'),
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
......
......@@ -8,7 +8,7 @@ model = dict(
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='fb'),
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
......
......@@ -8,7 +8,7 @@ model = dict(
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='fb'),
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment