Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
8a086f02
Commit
8a086f02
authored
Dec 24, 2018
by
ThangVu
Browse files
add frozen stage for group norm
parent
628441b7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
26 deletions
+26
-26
mmdet/models/backbones/resnet.py
mmdet/models/backbones/resnet.py
+26
-26
No files found.
mmdet/models/backbones/resnet.py
View file @
8a086f02
...
...
@@ -234,9 +234,9 @@ class ResNet(nn.Module):
dilations
=
(
1
,
1
,
1
,
1
),
out_indices
=
(
0
,
1
,
2
,
3
),
style
=
'pytorch'
,
frozen_stages
=-
1
,
normalize
=
dict
(
type
=
'BN'
,
frozen_stages
=-
1
,
bn_eval
=
True
,
bn_frozen
=
False
),
with_cp
=
False
):
...
...
@@ -245,7 +245,7 @@ class ResNet(nn.Module):
raise
KeyError
(
'invalid depth {} for resnet'
.
format
(
depth
))
assert
num_stages
>=
1
and
num_stages
<=
4
block
,
stage_blocks
=
self
.
arch_settings
[
depth
]
stage_blocks
=
stage_blocks
[:
num_stages
]
self
.
stage_blocks
=
stage_blocks
[:
num_stages
]
assert
len
(
strides
)
==
len
(
dilations
)
==
num_stages
assert
max
(
out_indices
)
<
num_stages
...
...
@@ -254,14 +254,14 @@ class ResNet(nn.Module):
if
normalize
[
'type'
]
==
'GN'
:
assert
'num_groups'
in
normalize
else
:
assert
(
set
([
'type'
,
'frozen_stages'
,
'bn_eval'
,
'bn_frozen'
])
assert
(
set
([
'type'
,
'bn_eval'
,
'bn_frozen'
])
==
set
(
normalize
))
self
.
out_indices
=
out_indices
self
.
style
=
style
self
.
with_cp
=
with_cp
self
.
frozen_stages
=
frozen_stages
if
normalize
[
'type'
]
==
'BN'
:
self
.
frozen_stages
=
normalize
[
'frozen_stages'
]
self
.
bn_eval
=
normalize
[
'bn_eval'
]
self
.
bn_frozen
=
normalize
[
'bn_frozen'
]
self
.
normalize
=
normalize
...
...
@@ -334,27 +334,27 @@ class ResNet(nn.Module):
def
train
(
self
,
mode
=
True
):
super
(
ResNet
,
self
).
train
(
mode
)
if
self
.
normalize
[
'type'
]
==
'BN'
:
if
self
.
bn_eval
:
for
m
in
self
.
modules
():
if
isinstance
(
m
,
nn
.
BatchNorm2d
):
m
.
eval
()
if
self
.
bn_frozen
:
for
params
in
m
.
parameters
():
params
.
requires_grad
=
False
if
mode
and
self
.
frozen_stages
>=
0
:
for
param
in
self
.
conv1
.
parameters
():
param
.
requires_grad
=
False
for
param
in
self
.
bn1
.
parameters
():
if
self
.
normalize
[
'type'
]
==
'BN'
and
self
.
bn_eval
:
for
m
in
self
.
modules
():
if
isinstance
(
m
,
nn
.
BatchNorm2d
):
m
.
eval
()
if
self
.
bn_frozen
:
for
params
in
m
.
parameters
():
params
.
requires_grad
=
False
if
mode
and
self
.
frozen_stages
>=
0
:
for
param
in
self
.
conv1
.
parameters
():
param
.
requires_grad
=
False
stem_norm
=
getattr
(
self
,
self
.
stem_norm_name
)
stem_norm
.
eval
()
for
param
in
stem_norm
.
parameters
():
param
.
requires_grad
=
False
for
i
in
range
(
1
,
self
.
frozen_stages
+
1
):
mod
=
getattr
(
self
,
'layer{}'
.
format
(
i
))
mod
.
eval
()
for
param
in
mod
.
parameters
():
param
.
requires_grad
=
False
self
.
bn1
.
eval
()
self
.
bn1
.
weight
.
requires_grad
=
False
self
.
bn1
.
bias
.
requires_grad
=
False
for
i
in
range
(
1
,
self
.
frozen_stages
+
1
):
mod
=
getattr
(
self
,
'layer{}'
.
format
(
i
))
mod
.
eval
()
for
param
in
mod
.
parameters
():
param
.
requires_grad
=
False
class
ResNetClassifier
(
ResNet
):
...
...
@@ -433,8 +433,8 @@ class ResNetClassifier(ResNet):
cf_state
=
pickle
.
load
(
f
,
encoding
=
'latin1'
)
if
'blobs'
in
cf_state
:
cf_state
=
cf_state
[
'blobs'
]
for
py_k
,
cf_k
in
mapping
.
items
():
print
(
'Loading {} to {}'
.
format
(
cf_k
,
py_k
))
for
i
,
(
py_k
,
cf_k
)
in
enumerate
(
mapping
.
items
()
,
1
)
:
print
(
'
[{}/{}]
Loading {} to {}'
.
format
(
i
,
len
(
mapping
),
cf_k
,
py_k
))
assert
py_k
in
py_state
and
cf_k
in
cf_state
py_state
[
py_k
]
=
torch
.
Tensor
(
cf_state
[
cf_k
])
self
.
load_state_dict
(
py_state
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment