Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
005bc473
"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "0bb9b9147a1cf6dec81b7130979d33f25c5720a9"
Commit
005bc473
authored
Jan 02, 2018
by
Yun Chen
Committed by
Alykhan Tejani
Jan 02, 2018
Browse files
make weight initialization optional to speed vgg-construction (#377)
parent
2b2aa9c7
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
2 deletions
+19
-2
torchvision/models/vgg.py
torchvision/models/vgg.py
+19
-2
No files found.
torchvision/models/vgg.py
View file @
005bc473
...
@@ -23,7 +23,7 @@ model_urls = {
...
@@ -23,7 +23,7 @@ model_urls = {
class
VGG
(
nn
.
Module
):
class
VGG
(
nn
.
Module
):
def
__init__
(
self
,
features
,
num_classes
=
1000
):
def
__init__
(
self
,
features
,
num_classes
=
1000
,
init_weights
=
True
):
super
(
VGG
,
self
).
__init__
()
super
(
VGG
,
self
).
__init__
()
self
.
features
=
features
self
.
features
=
features
self
.
classifier
=
nn
.
Sequential
(
self
.
classifier
=
nn
.
Sequential
(
...
@@ -35,6 +35,7 @@ class VGG(nn.Module):
...
@@ -35,6 +35,7 @@ class VGG(nn.Module):
nn
.
Dropout
(),
nn
.
Dropout
(),
nn
.
Linear
(
4096
,
num_classes
),
nn
.
Linear
(
4096
,
num_classes
),
)
)
if
init_weights
:
self
.
_initialize_weights
()
self
.
_initialize_weights
()
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
...
@@ -88,6 +89,8 @@ def vgg11(pretrained=False, **kwargs):
...
@@ -88,6 +89,8 @@ def vgg11(pretrained=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
if
pretrained
:
kwargs
[
'init_weights'
]
=
False
model
=
VGG
(
make_layers
(
cfg
[
'A'
]),
**
kwargs
)
model
=
VGG
(
make_layers
(
cfg
[
'A'
]),
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg11'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg11'
]))
...
@@ -100,6 +103,8 @@ def vgg11_bn(pretrained=False, **kwargs):
...
@@ -100,6 +103,8 @@ def vgg11_bn(pretrained=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
if
pretrained
:
kwargs
[
'init_weights'
]
=
False
model
=
VGG
(
make_layers
(
cfg
[
'A'
],
batch_norm
=
True
),
**
kwargs
)
model
=
VGG
(
make_layers
(
cfg
[
'A'
],
batch_norm
=
True
),
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg11_bn'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg11_bn'
]))
...
@@ -112,6 +117,8 @@ def vgg13(pretrained=False, **kwargs):
...
@@ -112,6 +117,8 @@ def vgg13(pretrained=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
if
pretrained
:
kwargs
[
'init_weights'
]
=
False
model
=
VGG
(
make_layers
(
cfg
[
'B'
]),
**
kwargs
)
model
=
VGG
(
make_layers
(
cfg
[
'B'
]),
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg13'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg13'
]))
...
@@ -124,6 +131,8 @@ def vgg13_bn(pretrained=False, **kwargs):
...
@@ -124,6 +131,8 @@ def vgg13_bn(pretrained=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
if
pretrained
:
kwargs
[
'init_weights'
]
=
False
model
=
VGG
(
make_layers
(
cfg
[
'B'
],
batch_norm
=
True
),
**
kwargs
)
model
=
VGG
(
make_layers
(
cfg
[
'B'
],
batch_norm
=
True
),
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg13_bn'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg13_bn'
]))
...
@@ -136,6 +145,8 @@ def vgg16(pretrained=False, **kwargs):
...
@@ -136,6 +145,8 @@ def vgg16(pretrained=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
if
pretrained
:
kwargs
[
'init_weights'
]
=
False
model
=
VGG
(
make_layers
(
cfg
[
'D'
]),
**
kwargs
)
model
=
VGG
(
make_layers
(
cfg
[
'D'
]),
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg16'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg16'
]))
...
@@ -148,6 +159,8 @@ def vgg16_bn(pretrained=False, **kwargs):
...
@@ -148,6 +159,8 @@ def vgg16_bn(pretrained=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
if
pretrained
:
kwargs
[
'init_weights'
]
=
False
model
=
VGG
(
make_layers
(
cfg
[
'D'
],
batch_norm
=
True
),
**
kwargs
)
model
=
VGG
(
make_layers
(
cfg
[
'D'
],
batch_norm
=
True
),
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg16_bn'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg16_bn'
]))
...
@@ -160,6 +173,8 @@ def vgg19(pretrained=False, **kwargs):
...
@@ -160,6 +173,8 @@ def vgg19(pretrained=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
if
pretrained
:
kwargs
[
'init_weights'
]
=
False
model
=
VGG
(
make_layers
(
cfg
[
'E'
]),
**
kwargs
)
model
=
VGG
(
make_layers
(
cfg
[
'E'
]),
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg19'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg19'
]))
...
@@ -172,6 +187,8 @@ def vgg19_bn(pretrained=False, **kwargs):
...
@@ -172,6 +187,8 @@ def vgg19_bn(pretrained=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
if
pretrained
:
kwargs
[
'init_weights'
]
=
False
model
=
VGG
(
make_layers
(
cfg
[
'E'
],
batch_norm
=
True
),
**
kwargs
)
model
=
VGG
(
make_layers
(
cfg
[
'E'
],
batch_norm
=
True
),
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg19_bn'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg19_bn'
]))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment