Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
dca2d841
Commit
dca2d841
authored
Dec 31, 2018
by
ThangVu
Browse files
revise group norm (2)
parent
f64c9561
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
44 additions
and
45 deletions
+44
-45
configs/mask_rcnn_r50_fpn_gn_2x.py
configs/mask_rcnn_r50_fpn_gn_2x.py
+1
-1
mmdet/models/backbones/resnet.py
mmdet/models/backbones/resnet.py
+33
-36
mmdet/models/utils/norm.py
mmdet/models/utils/norm.py
+10
-8
No files found.
configs/mask_rcnn_r50_fpn_gn_2x.py
View file @
dca2d841
...
@@ -12,7 +12,7 @@ model = dict(
...
@@ -12,7 +12,7 @@ model = dict(
normalize
=
dict
(
normalize
=
dict
(
type
=
'GN'
,
type
=
'GN'
,
num_groups
=
32
,
num_groups
=
32
,
eval
=
False
,
eval
_mode
=
False
,
frozen
=
False
)),
frozen
=
False
)),
neck
=
dict
(
neck
=
dict
(
type
=
'FPN'
,
type
=
'FPN'
,
...
...
mmdet/models/backbones/resnet.py
View file @
dca2d841
...
@@ -31,8 +31,7 @@ class BasicBlock(nn.Module):
...
@@ -31,8 +31,7 @@ class BasicBlock(nn.Module):
downsample
=
None
,
downsample
=
None
,
style
=
'pytorch'
,
style
=
'pytorch'
,
with_cp
=
False
,
with_cp
=
False
,
normalize
=
dict
(
type
=
'BN'
),
normalize
=
dict
(
type
=
'BN'
)):
frozen
=
False
):
super
(
BasicBlock
,
self
).
__init__
()
super
(
BasicBlock
,
self
).
__init__
()
self
.
conv1
=
conv3x3
(
inplanes
,
planes
,
stride
,
dilation
)
self
.
conv1
=
conv3x3
(
inplanes
,
planes
,
stride
,
dilation
)
...
@@ -51,10 +50,6 @@ class BasicBlock(nn.Module):
...
@@ -51,10 +50,6 @@ class BasicBlock(nn.Module):
self
.
dilation
=
dilation
self
.
dilation
=
dilation
assert
not
with_cp
assert
not
with_cp
if
frozen
:
for
param
in
self
.
parameters
():
param
.
requires_grad
=
False
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
identity
=
x
identity
=
x
...
@@ -85,8 +80,7 @@ class Bottleneck(nn.Module):
...
@@ -85,8 +80,7 @@ class Bottleneck(nn.Module):
downsample
=
None
,
downsample
=
None
,
style
=
'pytorch'
,
style
=
'pytorch'
,
with_cp
=
False
,
with_cp
=
False
,
normalize
=
dict
(
type
=
'BN'
),
normalize
=
dict
(
type
=
'BN'
)):
frozen
=
False
):
"""Bottleneck block for ResNet.
"""Bottleneck block for ResNet.
If style is "pytorch", the stride-two layer is the 3x3 conv layer,
If style is "pytorch", the stride-two layer is the 3x3 conv layer,
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
...
@@ -134,10 +128,6 @@ class Bottleneck(nn.Module):
...
@@ -134,10 +128,6 @@ class Bottleneck(nn.Module):
self
.
with_cp
=
with_cp
self
.
with_cp
=
with_cp
self
.
normalize
=
normalize
self
.
normalize
=
normalize
if
frozen
:
for
param
in
self
.
parameters
():
param
.
requires_grad
=
False
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
def
_inner_forward
(
x
):
def
_inner_forward
(
x
):
...
@@ -179,8 +169,7 @@ def make_res_layer(block,
...
@@ -179,8 +169,7 @@ def make_res_layer(block,
dilation
=
1
,
dilation
=
1
,
style
=
'pytorch'
,
style
=
'pytorch'
,
with_cp
=
False
,
with_cp
=
False
,
normalize
=
dict
(
type
=
'BN'
),
normalize
=
dict
(
type
=
'BN'
)):
frozen
=
False
):
downsample
=
None
downsample
=
None
if
stride
!=
1
or
inplanes
!=
planes
*
block
.
expansion
:
if
stride
!=
1
or
inplanes
!=
planes
*
block
.
expansion
:
downsample
=
nn
.
Sequential
(
downsample
=
nn
.
Sequential
(
...
@@ -203,8 +192,7 @@ def make_res_layer(block,
...
@@ -203,8 +192,7 @@ def make_res_layer(block,
downsample
,
downsample
,
style
=
style
,
style
=
style
,
with_cp
=
with_cp
,
with_cp
=
with_cp
,
normalize
=
normalize
,
normalize
=
normalize
))
frozen
=
frozen
))
inplanes
=
planes
*
block
.
expansion
inplanes
=
planes
*
block
.
expansion
for
i
in
range
(
1
,
blocks
):
for
i
in
range
(
1
,
blocks
):
layers
.
append
(
layers
.
append
(
...
@@ -253,9 +241,10 @@ class ResNet(nn.Module):
...
@@ -253,9 +241,10 @@ class ResNet(nn.Module):
frozen_stages
=-
1
,
frozen_stages
=-
1
,
normalize
=
dict
(
normalize
=
dict
(
type
=
'BN'
,
type
=
'BN'
,
eval
=
True
,
eval
_mode
=
True
,
frozen
=
False
),
frozen
=
False
),
with_cp
=
False
):
with_cp
=
False
,
zero_init_residual
=
True
):
super
(
ResNet
,
self
).
__init__
()
super
(
ResNet
,
self
).
__init__
()
if
depth
not
in
self
.
arch_settings
:
if
depth
not
in
self
.
arch_settings
:
raise
KeyError
(
'invalid depth {} for resnet'
.
format
(
depth
))
raise
KeyError
(
'invalid depth {} for resnet'
.
format
(
depth
))
...
@@ -268,12 +257,13 @@ class ResNet(nn.Module):
...
@@ -268,12 +257,13 @@ class ResNet(nn.Module):
self
.
out_indices
=
out_indices
self
.
out_indices
=
out_indices
assert
max
(
out_indices
)
<
num_stages
assert
max
(
out_indices
)
<
num_stages
self
.
style
=
style
self
.
style
=
style
self
.
with_cp
=
with_cp
self
.
frozen_stages
=
frozen_stages
self
.
is_frozen
=
[
i
<=
frozen_stages
for
i
in
range
(
num_stages
+
1
)]
assert
(
isinstance
(
normalize
,
dict
)
and
'eval_mode'
in
normalize
assert
(
isinstance
(
normalize
,
dict
)
and
'eval'
in
normalize
and
'frozen'
in
normalize
)
and
'frozen'
in
normalize
)
self
.
norm_eval
=
normalize
.
pop
(
'eval'
)
self
.
norm_eval
=
normalize
.
pop
(
'eval
_mode
'
)
self
.
normalize
=
normalize
self
.
normalize
=
normalize
self
.
with_cp
=
with_cp
self
.
zero_init_residual
=
zero_init_residual
self
.
block
,
stage_blocks
=
self
.
arch_settings
[
depth
]
self
.
block
,
stage_blocks
=
self
.
arch_settings
[
depth
]
self
.
stage_blocks
=
stage_blocks
[:
num_stages
]
self
.
stage_blocks
=
stage_blocks
[:
num_stages
]
self
.
inplanes
=
64
self
.
inplanes
=
64
...
@@ -294,13 +284,14 @@ class ResNet(nn.Module):
...
@@ -294,13 +284,14 @@ class ResNet(nn.Module):
dilation
=
dilation
,
dilation
=
dilation
,
style
=
self
.
style
,
style
=
self
.
style
,
with_cp
=
with_cp
,
with_cp
=
with_cp
,
normalize
=
normalize
,
normalize
=
normalize
)
frozen
=
self
.
is_frozen
[
i
+
1
])
self
.
inplanes
=
planes
*
self
.
block
.
expansion
self
.
inplanes
=
planes
*
self
.
block
.
expansion
layer_name
=
'layer{}'
.
format
(
i
+
1
)
layer_name
=
'layer{}'
.
format
(
i
+
1
)
self
.
add_module
(
layer_name
,
res_layer
)
self
.
add_module
(
layer_name
,
res_layer
)
self
.
res_layers
.
append
(
layer_name
)
self
.
res_layers
.
append
(
layer_name
)
self
.
_freeze_stages
()
self
.
feat_dim
=
self
.
block
.
expansion
*
64
*
2
**
(
self
.
feat_dim
=
self
.
block
.
expansion
*
64
*
2
**
(
len
(
self
.
stage_blocks
)
-
1
)
len
(
self
.
stage_blocks
)
-
1
)
...
@@ -313,11 +304,17 @@ class ResNet(nn.Module):
...
@@ -313,11 +304,17 @@ class ResNet(nn.Module):
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
self
.
maxpool
=
nn
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
self
.
maxpool
=
nn
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
if
self
.
is_
frozen
[
0
]
:
if
self
.
frozen
_stages
>=
0
:
for
layer
in
[
self
.
conv1
,
stem_norm
]:
for
m
in
[
self
.
conv1
,
stem_norm
]:
for
param
in
layer
.
parameters
():
for
param
in
m
.
parameters
():
param
.
requires_grad
=
False
param
.
requires_grad
=
False
def
_freeze_stages
(
self
):
for
i
in
range
(
1
,
self
.
frozen_stages
+
1
):
m
=
getattr
(
self
,
'layer{}'
.
format
(
i
))
for
param
in
m
.
parameters
():
param
.
requires_grad
=
False
def
init_weights
(
self
,
pretrained
=
None
):
def
init_weights
(
self
,
pretrained
=
None
):
if
isinstance
(
pretrained
,
str
):
if
isinstance
(
pretrained
,
str
):
logger
=
logging
.
getLogger
()
logger
=
logging
.
getLogger
()
...
@@ -326,15 +323,15 @@ class ResNet(nn.Module):
...
@@ -326,15 +323,15 @@ class ResNet(nn.Module):
for
m
in
self
.
modules
():
for
m
in
self
.
modules
():
if
isinstance
(
m
,
nn
.
Conv2d
):
if
isinstance
(
m
,
nn
.
Conv2d
):
kaiming_init
(
m
)
kaiming_init
(
m
)
elif
(
isinstance
(
m
,
nn
.
BatchNorm2d
)
elif
isinstance
(
m
,
(
nn
.
BatchNorm
,
nn
.
GroupNorm
)):
or
isinstance
(
m
,
nn
.
GroupNorm
)):
constant_init
(
m
,
1
)
constant_init
(
m
,
1
)
# zero init for last norm layer https://arxiv.org/abs/1706.02677
# zero init for last norm layer https://arxiv.org/abs/1706.02677
for
m
in
self
.
modules
():
if
self
.
zero_init_residual
:
if
isinstance
(
m
,
Bottleneck
)
or
isinstance
(
m
,
BasicBlock
):
for
m
in
self
.
modules
():
last_norm
=
getattr
(
m
,
m
.
norm_names
[
-
1
])
if
isinstance
(
m
,
(
Bottleneck
,
BasicBlock
)):
constant_init
(
last_norm
,
0
)
last_norm
=
getattr
(
m
,
m
.
norm_names
[
-
1
])
constant_init
(
last_norm
,
0
)
else
:
else
:
raise
TypeError
(
'pretrained must be a str or None'
)
raise
TypeError
(
'pretrained must be a str or None'
)
...
@@ -357,7 +354,7 @@ class ResNet(nn.Module):
...
@@ -357,7 +354,7 @@ class ResNet(nn.Module):
def
train
(
self
,
mode
=
True
):
def
train
(
self
,
mode
=
True
):
super
(
ResNet
,
self
).
train
(
mode
)
super
(
ResNet
,
self
).
train
(
mode
)
if
mode
and
self
.
norm_eval
:
if
mode
and
self
.
norm_eval
:
for
m
od
in
self
.
modules
():
for
m
in
self
.
modules
():
# trick: eval have effect on BatchNorm only
# trick: eval have effect on BatchNorm only
if
isinstance
(
self
,
nn
.
BatchNorm2d
):
if
isinstance
(
m
,
nn
.
BatchNorm2d
):
m
od
.
eval
()
m
.
eval
()
mmdet/models/utils/norm.py
View file @
dca2d841
...
@@ -15,21 +15,23 @@ def build_norm_layer(cfg, num_features):
...
@@ -15,21 +15,23 @@ def build_norm_layer(cfg, num_features):
"""
"""
assert
isinstance
(
cfg
,
dict
)
and
'type'
in
cfg
assert
isinstance
(
cfg
,
dict
)
and
'type'
in
cfg
cfg_
=
cfg
.
copy
()
cfg_
=
cfg
.
copy
()
layer_type
=
cfg_
.
pop
(
'type'
)
layer_type
=
cfg_
.
pop
(
'type'
)
frozen
=
cfg_
.
pop
(
'frozen'
)
if
'frozen'
in
cfg_
else
False
if
layer_type
not
in
norm_cfg
:
raise
KeyError
(
'Unrecognized norm type {}'
.
format
(
layer_type
))
elif
norm_cfg
[
layer_type
]
is
None
:
raise
NotImplementedError
frozen
=
cfg_
.
pop
(
'frozen'
,
False
)
# args name matching
# args name matching
if
layer_type
==
'GN'
:
if
layer_type
in
[
'GN'
]
:
assert
'num_groups'
in
cfg
assert
'num_groups'
in
cfg
cfg_
.
setdefault
(
'num_channels'
,
num_features
)
cfg_
.
setdefault
(
'num_channels'
,
num_features
)
elif
layer_type
==
'BN'
:
elif
layer_type
in
[
'BN'
]
:
cfg_
.
setdefault
(
'num_features'
,
num_features
)
cfg_
.
setdefault
(
'num_features'
,
num_features
)
cfg_
.
setdefault
(
'eps'
,
1e-5
)
else
:
if
layer_type
not
in
norm_cfg
:
raise
KeyError
(
'Unrecognized norm type {}'
.
format
(
layer_type
))
elif
norm_cfg
[
layer_type
]
is
None
:
raise
NotImplementedError
raise
NotImplementedError
cfg_
.
setdefault
(
'eps'
,
1e-5
)
norm
=
norm_cfg
[
layer_type
](
**
cfg_
)
norm
=
norm_cfg
[
layer_type
](
**
cfg_
)
if
frozen
:
if
frozen
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment