Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
005bc473
Commit
005bc473
authored
Jan 02, 2018
by
Yun Chen
Committed by
Alykhan Tejani
Jan 02, 2018
Browse files
make weight initialization optional to speed vgg-construction (#377)
parent
2b2aa9c7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
2 deletions
+19
-2
torchvision/models/vgg.py
torchvision/models/vgg.py
+19
-2
No files found.
torchvision/models/vgg.py
View file @
005bc473
...
@@ -23,7 +23,7 @@ model_urls = {
...
@@ -23,7 +23,7 @@ model_urls = {
class
VGG
(
nn
.
Module
):
class
VGG
(
nn
.
Module
):
def
__init__
(
self
,
features
,
num_classes
=
1000
):
def
__init__
(
self
,
features
,
num_classes
=
1000
,
init_weights
=
True
):
super
(
VGG
,
self
).
__init__
()
super
(
VGG
,
self
).
__init__
()
self
.
features
=
features
self
.
features
=
features
self
.
classifier
=
nn
.
Sequential
(
self
.
classifier
=
nn
.
Sequential
(
...
@@ -35,7 +35,8 @@ class VGG(nn.Module):
...
@@ -35,7 +35,8 @@ class VGG(nn.Module):
nn
.
Dropout
(),
nn
.
Dropout
(),
nn
.
Linear
(
4096
,
num_classes
),
nn
.
Linear
(
4096
,
num_classes
),
)
)
self
.
_initialize_weights
()
if
init_weights
:
self
.
_initialize_weights
()
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x
=
self
.
features
(
x
)
x
=
self
.
features
(
x
)
...
@@ -88,6 +89,8 @@ def vgg11(pretrained=False, **kwargs):
...
@@ -88,6 +89,8 @@ def vgg11(pretrained=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
if
pretrained
:
kwargs
[
'init_weights'
]
=
False
model
=
VGG
(
make_layers
(
cfg
[
'A'
]),
**
kwargs
)
model
=
VGG
(
make_layers
(
cfg
[
'A'
]),
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg11'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg11'
]))
...
@@ -100,6 +103,8 @@ def vgg11_bn(pretrained=False, **kwargs):
...
@@ -100,6 +103,8 @@ def vgg11_bn(pretrained=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
if
pretrained
:
kwargs
[
'init_weights'
]
=
False
model
=
VGG
(
make_layers
(
cfg
[
'A'
],
batch_norm
=
True
),
**
kwargs
)
model
=
VGG
(
make_layers
(
cfg
[
'A'
],
batch_norm
=
True
),
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg11_bn'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg11_bn'
]))
...
@@ -112,6 +117,8 @@ def vgg13(pretrained=False, **kwargs):
...
@@ -112,6 +117,8 @@ def vgg13(pretrained=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
if
pretrained
:
kwargs
[
'init_weights'
]
=
False
model
=
VGG
(
make_layers
(
cfg
[
'B'
]),
**
kwargs
)
model
=
VGG
(
make_layers
(
cfg
[
'B'
]),
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg13'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg13'
]))
...
@@ -124,6 +131,8 @@ def vgg13_bn(pretrained=False, **kwargs):
...
@@ -124,6 +131,8 @@ def vgg13_bn(pretrained=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
if
pretrained
:
kwargs
[
'init_weights'
]
=
False
model
=
VGG
(
make_layers
(
cfg
[
'B'
],
batch_norm
=
True
),
**
kwargs
)
model
=
VGG
(
make_layers
(
cfg
[
'B'
],
batch_norm
=
True
),
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg13_bn'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg13_bn'
]))
...
@@ -136,6 +145,8 @@ def vgg16(pretrained=False, **kwargs):
...
@@ -136,6 +145,8 @@ def vgg16(pretrained=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
if
pretrained
:
kwargs
[
'init_weights'
]
=
False
model
=
VGG
(
make_layers
(
cfg
[
'D'
]),
**
kwargs
)
model
=
VGG
(
make_layers
(
cfg
[
'D'
]),
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg16'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg16'
]))
...
@@ -148,6 +159,8 @@ def vgg16_bn(pretrained=False, **kwargs):
...
@@ -148,6 +159,8 @@ def vgg16_bn(pretrained=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
if
pretrained
:
kwargs
[
'init_weights'
]
=
False
model
=
VGG
(
make_layers
(
cfg
[
'D'
],
batch_norm
=
True
),
**
kwargs
)
model
=
VGG
(
make_layers
(
cfg
[
'D'
],
batch_norm
=
True
),
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg16_bn'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg16_bn'
]))
...
@@ -160,6 +173,8 @@ def vgg19(pretrained=False, **kwargs):
...
@@ -160,6 +173,8 @@ def vgg19(pretrained=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
if
pretrained
:
kwargs
[
'init_weights'
]
=
False
model
=
VGG
(
make_layers
(
cfg
[
'E'
]),
**
kwargs
)
model
=
VGG
(
make_layers
(
cfg
[
'E'
]),
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg19'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg19'
]))
...
@@ -172,6 +187,8 @@ def vgg19_bn(pretrained=False, **kwargs):
...
@@ -172,6 +187,8 @@ def vgg19_bn(pretrained=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
if
pretrained
:
kwargs
[
'init_weights'
]
=
False
model
=
VGG
(
make_layers
(
cfg
[
'E'
],
batch_norm
=
True
),
**
kwargs
)
model
=
VGG
(
make_layers
(
cfg
[
'E'
],
batch_norm
=
True
),
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg19_bn'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg19_bn'
]))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment