Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
0e747be8
"python/vscode:/vscode.git/clone" did not exist on "616b59f384ad13b824fa8bb634444b43967f8c8a"
Commit
0e747be8
authored
Oct 10, 2018
by
Kai Chen
Browse files
update resnet backbone
parent
e8397e43
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
63 additions
and
91 deletions
+63
-91
configs/faster_rcnn_r50_fpn_1x.py
configs/faster_rcnn_r50_fpn_1x.py
+1
-1
configs/mask_rcnn_r50_fpn_1x.py
configs/mask_rcnn_r50_fpn_1x.py
+1
-1
configs/rpn_r50_fpn_1x.py
configs/rpn_r50_fpn_1x.py
+1
-1
mmdet/models/backbones/__init__.py
mmdet/models/backbones/__init__.py
+2
-2
mmdet/models/backbones/resnet.py
mmdet/models/backbones/resnet.py
+58
-86
No files found.
configs/faster_rcnn_r50_fpn_1x.py
View file @
0e747be8
...
@@ -3,7 +3,7 @@ model = dict(
...
@@ -3,7 +3,7 @@ model = dict(
type
=
'FasterRCNN'
,
type
=
'FasterRCNN'
,
pretrained
=
'modelzoo://resnet50'
,
pretrained
=
'modelzoo://resnet50'
,
backbone
=
dict
(
backbone
=
dict
(
type
=
'
r
es
n
et'
,
type
=
'
R
es
N
et'
,
depth
=
50
,
depth
=
50
,
num_stages
=
4
,
num_stages
=
4
,
out_indices
=
(
0
,
1
,
2
,
3
),
out_indices
=
(
0
,
1
,
2
,
3
),
...
...
configs/mask_rcnn_r50_fpn_1x.py
View file @
0e747be8
...
@@ -3,7 +3,7 @@ model = dict(
...
@@ -3,7 +3,7 @@ model = dict(
type
=
'MaskRCNN'
,
type
=
'MaskRCNN'
,
pretrained
=
'modelzoo://resnet50'
,
pretrained
=
'modelzoo://resnet50'
,
backbone
=
dict
(
backbone
=
dict
(
type
=
'
r
es
n
et'
,
type
=
'
R
es
N
et'
,
depth
=
50
,
depth
=
50
,
num_stages
=
4
,
num_stages
=
4
,
out_indices
=
(
0
,
1
,
2
,
3
),
out_indices
=
(
0
,
1
,
2
,
3
),
...
...
configs/rpn_r50_fpn_1x.py
View file @
0e747be8
...
@@ -3,7 +3,7 @@ model = dict(
...
@@ -3,7 +3,7 @@ model = dict(
type
=
'RPN'
,
type
=
'RPN'
,
pretrained
=
'modelzoo://resnet50'
,
pretrained
=
'modelzoo://resnet50'
,
backbone
=
dict
(
backbone
=
dict
(
type
=
'
r
es
n
et'
,
type
=
'
R
es
N
et'
,
depth
=
50
,
depth
=
50
,
num_stages
=
4
,
num_stages
=
4
,
out_indices
=
(
0
,
1
,
2
,
3
),
out_indices
=
(
0
,
1
,
2
,
3
),
...
...
mmdet/models/backbones/__init__.py
View file @
0e747be8
from
.resnet
import
r
es
n
et
from
.resnet
import
R
es
N
et
__all__
=
[
'
r
es
n
et'
]
__all__
=
[
'
R
es
N
et'
]
mmdet/models/backbones/resnet.py
View file @
0e747be8
import
logging
import
logging
import
math
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.utils.checkpoint
as
cp
import
torch.utils.checkpoint
as
cp
from
mmcv.cnn
import
constant_init
,
kaiming_init
from
mmcv.runner
import
load_checkpoint
from
mmcv.runner
import
load_checkpoint
...
@@ -27,7 +28,8 @@ class BasicBlock(nn.Module):
...
@@ -27,7 +28,8 @@ class BasicBlock(nn.Module):
stride
=
1
,
stride
=
1
,
dilation
=
1
,
dilation
=
1
,
downsample
=
None
,
downsample
=
None
,
style
=
'pytorch'
):
style
=
'pytorch'
,
with_cp
=
False
):
super
(
BasicBlock
,
self
).
__init__
()
super
(
BasicBlock
,
self
).
__init__
()
self
.
conv1
=
conv3x3
(
inplanes
,
planes
,
stride
,
dilation
)
self
.
conv1
=
conv3x3
(
inplanes
,
planes
,
stride
,
dilation
)
self
.
bn1
=
nn
.
BatchNorm2d
(
planes
)
self
.
bn1
=
nn
.
BatchNorm2d
(
planes
)
...
@@ -37,6 +39,7 @@ class BasicBlock(nn.Module):
...
@@ -37,6 +39,7 @@ class BasicBlock(nn.Module):
self
.
downsample
=
downsample
self
.
downsample
=
downsample
self
.
stride
=
stride
self
.
stride
=
stride
self
.
dilation
=
dilation
self
.
dilation
=
dilation
assert
not
with_cp
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
residual
=
x
residual
=
x
...
@@ -69,7 +72,6 @@ class Bottleneck(nn.Module):
...
@@ -69,7 +72,6 @@ class Bottleneck(nn.Module):
style
=
'pytorch'
,
style
=
'pytorch'
,
with_cp
=
False
):
with_cp
=
False
):
"""Bottleneck block.
"""Bottleneck block.
If style is "pytorch", the stride-two layer is the 3x3 conv layer,
If style is "pytorch", the stride-two layer is the 3x3 conv layer,
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
"""
...
@@ -174,64 +176,73 @@ def make_res_layer(block,
...
@@ -174,64 +176,73 @@ def make_res_layer(block,
return
nn
.
Sequential
(
*
layers
)
return
nn
.
Sequential
(
*
layers
)
class
ResHead
(
nn
.
Module
):
class
ResNet
(
nn
.
Module
):
"""ResNet backbone.
def
__init__
(
self
,
block
,
num_blocks
,
stride
=
2
,
dilation
=
1
,
style
=
'pytorch'
):
self
.
layer4
=
make_res_layer
(
block
,
1024
,
512
,
num_blocks
,
stride
=
stride
,
dilation
=
dilation
,
style
=
style
)
def
forward
(
self
,
x
):
return
self
.
layer4
(
x
)
Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
num_stages (int): Resnet stages, normally 4.
strides (Sequence[int]): Strides of the first block of each stage.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
bn_eval (bool): Whether to set BN layers to eval mode, namely, freeze
running stats (mean and var).
bn_frozen (bool): Whether to freeze weight and bias of BN layers.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
"""
class
ResNet
(
nn
.
Module
):
arch_settings
=
{
18
:
(
BasicBlock
,
(
2
,
2
,
2
,
2
)),
34
:
(
BasicBlock
,
(
3
,
4
,
6
,
3
)),
50
:
(
Bottleneck
,
(
3
,
4
,
6
,
3
)),
101
:
(
Bottleneck
,
(
3
,
4
,
23
,
3
)),
152
:
(
Bottleneck
,
(
3
,
8
,
36
,
3
))
}
def
__init__
(
self
,
def
__init__
(
self
,
block
,
depth
,
layers
,
num_stages
=
4
,
strides
=
(
1
,
2
,
2
,
2
),
strides
=
(
1
,
2
,
2
,
2
),
dilations
=
(
1
,
1
,
1
,
1
),
dilations
=
(
1
,
1
,
1
,
1
),
out_indices
=
(
0
,
1
,
2
,
3
),
out_indices
=
(
0
,
1
,
2
,
3
),
frozen_stages
=-
1
,
style
=
'pytorch'
,
style
=
'pytorch'
,
sync_bn
=
False
,
frozen_stages
=-
1
,
with_cp
=
False
,
bn_eval
=
True
,
strict_frozen
=
False
):
bn_frozen
=
False
,
with_cp
=
False
):
super
(
ResNet
,
self
).
__init__
()
super
(
ResNet
,
self
).
__init__
()
if
not
len
(
layers
)
==
len
(
strides
)
==
len
(
dilations
):
if
depth
not
in
self
.
arch_settings
:
raise
ValueError
(
raise
KeyError
(
'invalid depth {} for resnet'
.
format
(
depth
))
'The number of layers, strides and dilations must be equal, '
assert
num_stages
>=
1
and
num_stages
<=
4
'but found have {} layers, {} strides and {} dilations'
.
format
(
block
,
stage_blocks
=
self
.
arch_settings
[
depth
]
len
(
layers
),
len
(
strides
),
len
(
dilations
)))
stage_blocks
=
stage_blocks
[:
num_stages
]
assert
max
(
out_indices
)
<
len
(
layers
)
assert
len
(
strides
)
==
len
(
dilations
)
==
num_stages
assert
max
(
out_indices
)
<
num_stages
self
.
out_indices
=
out_indices
self
.
out_indices
=
out_indices
self
.
frozen_stages
=
frozen_stages
self
.
style
=
style
self
.
style
=
style
self
.
sync_bn
=
sync_bn
self
.
frozen_stages
=
frozen_stages
self
.
bn_eval
=
bn_eval
self
.
bn_frozen
=
bn_frozen
self
.
with_cp
=
with_cp
self
.
inplanes
=
64
self
.
inplanes
=
64
self
.
conv1
=
nn
.
Conv2d
(
self
.
conv1
=
nn
.
Conv2d
(
3
,
64
,
kernel_size
=
7
,
stride
=
2
,
padding
=
3
,
bias
=
False
)
3
,
64
,
kernel_size
=
7
,
stride
=
2
,
padding
=
3
,
bias
=
False
)
self
.
bn1
=
nn
.
BatchNorm2d
(
64
)
self
.
bn1
=
nn
.
BatchNorm2d
(
64
)
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
self
.
maxpool
=
nn
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
self
.
maxpool
=
nn
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
self
.
res_layers
=
[]
for
i
,
num_blocks
in
enumerate
(
layers
):
self
.
res_layers
=
[]
for
i
,
num_blocks
in
enumerate
(
stage_blocks
):
stride
=
strides
[
i
]
stride
=
strides
[
i
]
dilation
=
dilations
[
i
]
dilation
=
dilations
[
i
]
layer_name
=
'layer{}'
.
format
(
i
+
1
)
planes
=
64
*
2
**
i
planes
=
64
*
2
**
i
res_layer
=
make_res_layer
(
res_layer
=
make_res_layer
(
block
,
block
,
...
@@ -243,12 +254,11 @@ class ResNet(nn.Module):
...
@@ -243,12 +254,11 @@ class ResNet(nn.Module):
style
=
self
.
style
,
style
=
self
.
style
,
with_cp
=
with_cp
)
with_cp
=
with_cp
)
self
.
inplanes
=
planes
*
block
.
expansion
self
.
inplanes
=
planes
*
block
.
expansion
layer_name
=
'layer{}'
.
format
(
i
+
1
)
self
.
add_module
(
layer_name
,
res_layer
)
self
.
add_module
(
layer_name
,
res_layer
)
self
.
res_layers
.
append
(
layer_name
)
self
.
res_layers
.
append
(
layer_name
)
self
.
feat_dim
=
block
.
expansion
*
64
*
2
**
(
len
(
layers
)
-
1
)
self
.
with_cp
=
with_cp
self
.
strict_frozen
=
strict_frozen
self
.
feat_dim
=
block
.
expansion
*
64
*
2
**
(
len
(
stage_blocks
)
-
1
)
def
init_weights
(
self
,
pretrained
=
None
):
def
init_weights
(
self
,
pretrained
=
None
):
if
isinstance
(
pretrained
,
str
):
if
isinstance
(
pretrained
,
str
):
...
@@ -257,11 +267,9 @@ class ResNet(nn.Module):
...
@@ -257,11 +267,9 @@ class ResNet(nn.Module):
elif
pretrained
is
None
:
elif
pretrained
is
None
:
for
m
in
self
.
modules
():
for
m
in
self
.
modules
():
if
isinstance
(
m
,
nn
.
Conv2d
):
if
isinstance
(
m
,
nn
.
Conv2d
):
n
=
m
.
kernel_size
[
0
]
*
m
.
kernel_size
[
1
]
*
m
.
out_channels
kaiming_init
(
m
)
nn
.
init
.
normal_
(
m
.
weight
,
0
,
math
.
sqrt
(
2.
/
n
))
elif
isinstance
(
m
,
nn
.
BatchNorm2d
):
elif
isinstance
(
m
,
nn
.
BatchNorm2d
):
nn
.
init
.
constant_
(
m
.
weight
,
1
)
constant_init
(
m
,
1
)
nn
.
init
.
constant_
(
m
.
bias
,
0
)
else
:
else
:
raise
TypeError
(
'pretrained must be a str or None'
)
raise
TypeError
(
'pretrained must be a str or None'
)
...
@@ -283,11 +291,11 @@ class ResNet(nn.Module):
...
@@ -283,11 +291,11 @@ class ResNet(nn.Module):
def
train
(
self
,
mode
=
True
):
def
train
(
self
,
mode
=
True
):
super
(
ResNet
,
self
).
train
(
mode
)
super
(
ResNet
,
self
).
train
(
mode
)
if
not
self
.
sync_bn
:
if
self
.
bn_eval
:
for
m
in
self
.
modules
():
for
m
in
self
.
modules
():
if
isinstance
(
m
,
nn
.
BatchNorm2d
):
if
isinstance
(
m
,
nn
.
BatchNorm2d
):
m
.
eval
()
m
.
eval
()
if
self
.
strict
_frozen
:
if
self
.
bn
_frozen
:
for
params
in
m
.
parameters
():
for
params
in
m
.
parameters
():
params
.
requires_grad
=
False
params
.
requires_grad
=
False
if
mode
and
self
.
frozen_stages
>=
0
:
if
mode
and
self
.
frozen_stages
>=
0
:
...
@@ -303,39 +311,3 @@ class ResNet(nn.Module):
...
@@ -303,39 +311,3 @@ class ResNet(nn.Module):
mod
.
eval
()
mod
.
eval
()
for
param
in
mod
.
parameters
():
for
param
in
mod
.
parameters
():
param
.
requires_grad
=
False
param
.
requires_grad
=
False
resnet_cfg
=
{
18
:
(
BasicBlock
,
(
2
,
2
,
2
,
2
)),
34
:
(
BasicBlock
,
(
3
,
4
,
6
,
3
)),
50
:
(
Bottleneck
,
(
3
,
4
,
6
,
3
)),
101
:
(
Bottleneck
,
(
3
,
4
,
23
,
3
)),
152
:
(
Bottleneck
,
(
3
,
8
,
36
,
3
))
}
def
resnet
(
depth
,
num_stages
=
4
,
strides
=
(
1
,
2
,
2
,
2
),
dilations
=
(
1
,
1
,
1
,
1
),
out_indices
=
(
2
,
),
frozen_stages
=-
1
,
style
=
'pytorch'
,
sync_bn
=
False
,
with_cp
=
False
,
strict_frozen
=
False
):
"""Constructs a ResNet model.
Args:
depth (int): depth of resnet, from {18, 34, 50, 101, 152}
num_stages (int): num of resnet stages, normally 4
strides (list): strides of the first block of each stage
dilations (list): dilation of each stage
out_indices (list): output from which stages
"""
if
depth
not
in
resnet_cfg
:
raise
KeyError
(
'invalid depth {} for resnet'
.
format
(
depth
))
block
,
layers
=
resnet_cfg
[
depth
]
model
=
ResNet
(
block
,
layers
[:
num_stages
],
strides
,
dilations
,
out_indices
,
frozen_stages
,
style
,
sync_bn
,
with_cp
,
strict_frozen
)
return
model
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment