Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
04fde085
Commit
04fde085
authored
Feb 21, 2017
by
Sam Gross
Committed by
Soumith Chintala
Feb 21, 2017
Browse files
Add ImageNet trained VGG models and fix weight initialization (#62)
parent
683852d2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
66 additions
and
12 deletions
+66
-12
torchvision/models/vgg.py
torchvision/models/vgg.py
+66
-12
No files found.
torchvision/models/vgg.py
View file @
04fde085
import
torch.nn
as
nn
import
torch.utils.model_zoo
as
model_zoo
import
math
__all__
=
[
...
...
@@ -7,6 +9,14 @@ __all__ = [
]
model_urls
=
{
'vgg11'
:
'https://s3.amazonaws.com/pytorch/models/vgg11-fb7e83b2.pth'
,
'vgg13'
:
'https://s3.amazonaws.com/pytorch/models/vgg13-58758d87.pth'
,
'vgg16'
:
'https://s3.amazonaws.com/pytorch/models/vgg16-82412952.pth'
,
'vgg19'
:
'https://s3.amazonaws.com/pytorch/models/vgg19-341d7465.pth'
,
}
class
VGG
(
nn
.
Module
):
def
__init__
(
self
,
features
):
super
(
VGG
,
self
).
__init__
()
...
...
@@ -20,6 +30,7 @@ class VGG(nn.Module):
nn
.
ReLU
(
True
),
nn
.
Linear
(
4096
,
1000
),
)
self
.
_initialize_weights
()
def
forward
(
self
,
x
):
x
=
self
.
features
(
x
)
...
...
@@ -27,6 +38,21 @@ class VGG(nn.Module):
x
=
self
.
classifier
(
x
)
return
x
def
_initialize_weights
(
self
):
for
m
in
self
.
modules
():
if
isinstance
(
m
,
nn
.
Conv2d
):
n
=
m
.
kernel_size
[
0
]
*
m
.
kernel_size
[
1
]
*
m
.
out_channels
m
.
weight
.
data
.
normal_
(
0
,
math
.
sqrt
(
2.
/
n
))
if
m
.
bias
is
not
None
:
m
.
bias
.
data
.
zero_
()
elif
isinstance
(
m
,
nn
.
BatchNorm2d
):
m
.
weight
.
data
.
fill_
(
1
)
m
.
bias
.
data
.
zero_
()
elif
isinstance
(
m
,
nn
.
Linear
):
n
=
m
.
weight
.
size
(
1
)
m
.
weight
.
data
.
normal_
(
0
,
0.01
)
m
.
bias
.
data
.
zero_
()
def
make_layers
(
cfg
,
batch_norm
=
False
):
layers
=
[]
...
...
@@ -52,9 +78,16 @@ cfg = {
}
def
vgg11
():
"""VGG 11-layer model (configuration "A")"""
return
VGG
(
make_layers
(
cfg
[
'A'
]))
def
vgg11
(
pretrained
=
False
):
"""VGG 11-layer model (configuration "A")
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model
=
VGG
(
make_layers
(
cfg
[
'A'
]))
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg11'
]))
return
model
def
vgg11_bn
():
...
...
@@ -62,9 +95,16 @@ def vgg11_bn():
return
VGG
(
make_layers
(
cfg
[
'A'
],
batch_norm
=
True
))
def
vgg13
():
"""VGG 13-layer model (configuration "B")"""
return
VGG
(
make_layers
(
cfg
[
'B'
]))
def
vgg13
(
pretrained
=
False
):
"""VGG 13-layer model (configuration "B")
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model
=
VGG
(
make_layers
(
cfg
[
'B'
]))
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg13'
]))
return
model
def
vgg13_bn
():
...
...
@@ -72,9 +112,16 @@ def vgg13_bn():
return
VGG
(
make_layers
(
cfg
[
'B'
],
batch_norm
=
True
))
def
vgg16
():
"""VGG 16-layer model (configuration "D")"""
return
VGG
(
make_layers
(
cfg
[
'D'
]))
def
vgg16
(
pretrained
=
False
):
"""VGG 16-layer model (configuration "D")
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model
=
VGG
(
make_layers
(
cfg
[
'D'
]))
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg16'
]))
return
model
def
vgg16_bn
():
...
...
@@ -82,9 +129,16 @@ def vgg16_bn():
return
VGG
(
make_layers
(
cfg
[
'D'
],
batch_norm
=
True
))
def
vgg19
():
"""VGG 19-layer model (configuration "E")"""
return
VGG
(
make_layers
(
cfg
[
'E'
]))
def
vgg19
(
pretrained
=
False
):
"""VGG 19-layer model (configuration "E")
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model
=
VGG
(
make_layers
(
cfg
[
'E'
]))
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg19'
]))
return
model
def
vgg19_bn
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment