Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
03e25734
Unverified
Commit
03e25734
authored
Jun 24, 2019
by
Francisco Massa
Committed by
GitHub
Jun 24, 2019
Browse files
Update URL and add progress option (#1043)
parent
3254560b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
12 deletions
+12
-12
torchvision/models/mnasnet.py
torchvision/models/mnasnet.py
+12
-12
No files found.
torchvision/models/mnasnet.py
View file @
03e25734
...
...
@@ -8,10 +8,10 @@ __all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3']
_MODEL_URLS
=
{
"mnasnet0_5"
:
"https://
github.com/1e100/mnasnet_trainer/releases/download/v0.1
/mnasnet0.5_top1_67.592-7c6cb539b9.pth"
,
"https://
download.pytorch.org/models
/mnasnet0.5_top1_67.592-7c6cb539b9.pth"
,
"mnasnet0_75"
:
None
,
"mnasnet1_0"
:
"https://
github.com/1e100/mnasnet_trainer/releases/download/v0.1
/mnasnet1.0_top1_73.512-f206786ef8.pth"
,
"https://
download.pytorch.org/models
/mnasnet1.0_top1_73.512-f206786ef8.pth"
,
"mnasnet1_3"
:
None
}
...
...
@@ -143,41 +143,41 @@ class MNASNet(torch.nn.Module):
nn
.
init
.
zeros_
(
m
.
bias
)
def
_load_pretrained
(
model_name
,
model
):
def
_load_pretrained
(
model_name
,
model
,
progress
):
if
model_name
not
in
_MODEL_URLS
or
_MODEL_URLS
[
model_name
]
is
None
:
raise
ValueError
(
"No checkpoint is available for model type {}"
.
format
(
model_name
))
checkpoint_url
=
_MODEL_URLS
[
model_name
]
model
.
load_state_dict
(
load_state_dict_from_url
(
checkpoint_url
))
model
.
load_state_dict
(
load_state_dict_from_url
(
checkpoint_url
,
progress
=
progress
))
def
mnasnet0_5
(
pretrained
=
False
,
**
kwargs
):
def
mnasnet0_5
(
pretrained
=
False
,
progress
=
True
,
**
kwargs
):
""" MNASNet with depth multiplier of 0.5. """
model
=
MNASNet
(
0.5
,
**
kwargs
)
if
pretrained
:
_load_pretrained
(
"mnasnet0_5"
,
model
)
_load_pretrained
(
"mnasnet0_5"
,
model
,
progress
)
return
model
def
mnasnet0_75
(
pretrained
=
False
,
**
kwargs
):
def
mnasnet0_75
(
pretrained
=
False
,
progress
=
True
,
**
kwargs
):
""" MNASNet with depth multiplier of 0.75. """
model
=
MNASNet
(
0.75
,
**
kwargs
)
if
pretrained
:
_load_pretrained
(
"mnasnet0_75"
,
model
)
_load_pretrained
(
"mnasnet0_75"
,
model
,
progress
)
return
model
def
mnasnet1_0
(
pretrained
=
False
,
**
kwargs
):
def
mnasnet1_0
(
pretrained
=
False
,
progress
=
True
,
**
kwargs
):
""" MNASNet with depth multiplier of 1.0. """
model
=
MNASNet
(
1.0
,
**
kwargs
)
if
pretrained
:
_load_pretrained
(
"mnasnet1_0"
,
model
)
_load_pretrained
(
"mnasnet1_0"
,
model
,
progress
)
return
model
def
mnasnet1_3
(
pretrained
=
False
,
**
kwargs
):
def
mnasnet1_3
(
pretrained
=
False
,
progress
=
True
,
**
kwargs
):
""" MNASNet with depth multiplier of 1.3. """
model
=
MNASNet
(
1.3
,
**
kwargs
)
if
pretrained
:
_load_pretrained
(
"mnasnet1_3"
,
model
)
_load_pretrained
(
"mnasnet1_3"
,
model
,
progress
)
return
model
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment