Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
742fd13c
Commit
742fd13c
authored
Apr 01, 2019
by
Sepehr Sameni
Committed by
Francisco Massa
Apr 01, 2019
Browse files
remove duplicate code from densenet (#827)
* remove duplicate code from densenet * correct indentation
parent
6f2f9213
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
56 deletions
+21
-56
torchvision/models/densenet.py
torchvision/models/densenet.py
+21
-56
No files found.
torchvision/models/densenet.py
View file @
742fd13c
...
...
@@ -117,6 +117,23 @@ class DenseNet(nn.Module):
return
out
def
_load_state_dict
(
model
,
model_url
):
# '.'s are no longer allowed in module names, but pervious _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern
=
re
.
compile
(
r
'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
)
state_dict
=
model_zoo
.
load_url
(
model_url
)
for
key
in
list
(
state_dict
.
keys
()):
res
=
pattern
.
match
(
key
)
if
res
:
new_key
=
res
.
group
(
1
)
+
res
.
group
(
2
)
state_dict
[
new_key
]
=
state_dict
[
key
]
del
state_dict
[
key
]
model
.
load_state_dict
(
state_dict
)
def
densenet121
(
pretrained
=
False
,
**
kwargs
):
r
"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
...
...
@@ -127,20 +144,7 @@ def densenet121(pretrained=False, **kwargs):
model
=
DenseNet
(
num_init_features
=
64
,
growth_rate
=
32
,
block_config
=
(
6
,
12
,
24
,
16
),
**
kwargs
)
if
pretrained
:
# '.'s are no longer allowed in module names, but pervious _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern
=
re
.
compile
(
r
'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
)
state_dict
=
model_zoo
.
load_url
(
model_urls
[
'densenet121'
])
for
key
in
list
(
state_dict
.
keys
()):
res
=
pattern
.
match
(
key
)
if
res
:
new_key
=
res
.
group
(
1
)
+
res
.
group
(
2
)
state_dict
[
new_key
]
=
state_dict
[
key
]
del
state_dict
[
key
]
model
.
load_state_dict
(
state_dict
)
_load_state_dict
(
model
,
model_urls
[
'densenet121'
])
return
model
...
...
@@ -154,20 +158,7 @@ def densenet169(pretrained=False, **kwargs):
model
=
DenseNet
(
num_init_features
=
64
,
growth_rate
=
32
,
block_config
=
(
6
,
12
,
32
,
32
),
**
kwargs
)
if
pretrained
:
# '.'s are no longer allowed in module names, but pervious _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern
=
re
.
compile
(
r
'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
)
state_dict
=
model_zoo
.
load_url
(
model_urls
[
'densenet169'
])
for
key
in
list
(
state_dict
.
keys
()):
res
=
pattern
.
match
(
key
)
if
res
:
new_key
=
res
.
group
(
1
)
+
res
.
group
(
2
)
state_dict
[
new_key
]
=
state_dict
[
key
]
del
state_dict
[
key
]
model
.
load_state_dict
(
state_dict
)
_load_state_dict
(
model
,
model_urls
[
'densenet169'
])
return
model
...
...
@@ -181,20 +172,7 @@ def densenet201(pretrained=False, **kwargs):
model
=
DenseNet
(
num_init_features
=
64
,
growth_rate
=
32
,
block_config
=
(
6
,
12
,
48
,
32
),
**
kwargs
)
if
pretrained
:
# '.'s are no longer allowed in module names, but pervious _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern
=
re
.
compile
(
r
'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
)
state_dict
=
model_zoo
.
load_url
(
model_urls
[
'densenet201'
])
for
key
in
list
(
state_dict
.
keys
()):
res
=
pattern
.
match
(
key
)
if
res
:
new_key
=
res
.
group
(
1
)
+
res
.
group
(
2
)
state_dict
[
new_key
]
=
state_dict
[
key
]
del
state_dict
[
key
]
model
.
load_state_dict
(
state_dict
)
_load_state_dict
(
model
,
model_urls
[
'densenet201'
])
return
model
...
...
@@ -208,18 +186,5 @@ def densenet161(pretrained=False, **kwargs):
model
=
DenseNet
(
num_init_features
=
96
,
growth_rate
=
48
,
block_config
=
(
6
,
12
,
36
,
24
),
**
kwargs
)
if
pretrained
:
# '.'s are no longer allowed in module names, but pervious _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern
=
re
.
compile
(
r
'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
)
state_dict
=
model_zoo
.
load_url
(
model_urls
[
'densenet161'
])
for
key
in
list
(
state_dict
.
keys
()):
res
=
pattern
.
match
(
key
)
if
res
:
new_key
=
res
.
group
(
1
)
+
res
.
group
(
2
)
state_dict
[
new_key
]
=
state_dict
[
key
]
del
state_dict
[
key
]
model
.
load_state_dict
(
state_dict
)
_load_state_dict
(
model
,
model_urls
[
'densenet161'
])
return
model
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment