OpenDAS / vision

Commit 791c172a (unverified)
Authored Jan 30, 2020 by os-gabe; committed by GitHub on Jan 30, 2020
Fixes #1797 by adding an init_weights keyword argument to Inception3 (#1832)
parent f2600c2e
Showing 1 changed file with 14 additions and 14 deletions.
torchvision/models/inception.py  +14 -14  (view file @ 791c172a)
@@ -65,7 +65,7 @@ def inception_v3(pretrained=False, progress=True, **kwargs):
 class Inception3(nn.Module):
 
     def __init__(self, num_classes=1000, aux_logits=True, transform_input=False,
-                 inception_blocks=None):
+                 inception_blocks=None, init_weights=True):
         super(Inception3, self).__init__()
         if inception_blocks is None:
             inception_blocks = [
@@ -102,19 +102,19 @@ class Inception3(nn.Module):
         self.Mixed_7b = inception_e(1280)
         self.Mixed_7c = inception_e(2048)
         self.fc = nn.Linear(2048, num_classes)
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
-                import scipy.stats as stats
-                stddev = m.stddev if hasattr(m, 'stddev') else 0.1
-                X = stats.truncnorm(-2, 2, scale=stddev)
-                values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
-                values = values.view(m.weight.size())
-                with torch.no_grad():
-                    m.weight.copy_(values)
-            elif isinstance(m, nn.BatchNorm2d):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
+        if init_weights:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+                    import scipy.stats as stats
+                    stddev = m.stddev if hasattr(m, 'stddev') else 0.1
+                    X = stats.truncnorm(-2, 2, scale=stddev)
+                    values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
+                    values = values.view(m.weight.size())
+                    with torch.no_grad():
+                        m.weight.copy_(values)
+                elif isinstance(m, nn.BatchNorm2d):
+                    nn.init.constant_(m.weight, 1)
+                    nn.init.constant_(m.bias, 0)
 
     def _transform_input(self, x):
         if self.transform_input:
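For reference, a minimal usage sketch of the new keyword argument (not part of the commit itself; the checkpoint filename below is a placeholder). Passing init_weights=False skips the scipy-backed truncated-normal initialization that the diff now guards, which is useful when the weights are about to be overwritten, for example by a saved state dict:

import torch
from torchvision.models.inception import Inception3

# Default: init_weights=True runs the truncated-normal / BatchNorm
# initialization loop shown in the diff, exactly as before this commit.
model = Inception3(num_classes=1000, aux_logits=True)

# Skip the initialization pass when the weights will be replaced anyway.
model = Inception3(num_classes=1000, aux_logits=True, init_weights=False)
state_dict = torch.load("inception_v3_checkpoint.pth")  # placeholder path
model.load_state_dict(state_dict)

Gating the loop behind init_weights keeps the default behaviour unchanged while letting callers that immediately load pretrained weights avoid a redundant initialization pass.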
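As a side note, a hypothetical helper equivalent in intent to the guarded loop, assuming a newer PyTorch release that ships torch.nn.init.trunc_normal_ (which did not exist when this commit was made and is not part of it). It reproduces the same truncation at two standard deviations without importing scipy:

import torch.nn as nn

def init_inception_weights(module: nn.Module) -> None:
    # Hypothetical helper, not part of this commit.
    for m in module.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            # Some Inception submodules carry a per-layer stddev attribute.
            stddev = float(m.stddev) if hasattr(m, 'stddev') else 0.1
            # scipy.stats.truncnorm(-2, 2, scale=stddev) truncates at two
            # standard deviations, i.e. absolute bounds of +/- 2 * stddev.
            nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev,
                                  a=-2 * stddev, b=2 * stddev)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)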