Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
03e25734
"...text-generation-inference.git" did not exist on "e68509add7ac04f8d6b4b75fab6fa65f47c2a76c"
Unverified
Commit
03e25734
authored
Jun 24, 2019
by
Francisco Massa
Committed by
GitHub
Jun 24, 2019
Browse files
Update URL and add progress option (#1043)
parent
3254560b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
12 deletions
+12
-12
torchvision/models/mnasnet.py
torchvision/models/mnasnet.py
+12
-12
No files found.
torchvision/models/mnasnet.py
View file @
03e25734
...
...
@@ -8,10 +8,10 @@ __all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3']
_MODEL_URLS
=
{
"mnasnet0_5"
:
"https://
github.com/1e100/mnasnet_trainer/releases/download/v0.1
/mnasnet0.5_top1_67.592-7c6cb539b9.pth"
,
"https://
download.pytorch.org/models
/mnasnet0.5_top1_67.592-7c6cb539b9.pth"
,
"mnasnet0_75"
:
None
,
"mnasnet1_0"
:
"https://
github.com/1e100/mnasnet_trainer/releases/download/v0.1
/mnasnet1.0_top1_73.512-f206786ef8.pth"
,
"https://
download.pytorch.org/models
/mnasnet1.0_top1_73.512-f206786ef8.pth"
,
"mnasnet1_3"
:
None
}
...
...
@@ -143,41 +143,41 @@ class MNASNet(torch.nn.Module):
nn
.
init
.
zeros_
(
m
.
bias
)
def
_load_pretrained
(
model_name
,
model
):
def
_load_pretrained
(
model_name
,
model
,
progress
):
if
model_name
not
in
_MODEL_URLS
or
_MODEL_URLS
[
model_name
]
is
None
:
raise
ValueError
(
"No checkpoint is available for model type {}"
.
format
(
model_name
))
checkpoint_url
=
_MODEL_URLS
[
model_name
]
model
.
load_state_dict
(
load_state_dict_from_url
(
checkpoint_url
))
model
.
load_state_dict
(
load_state_dict_from_url
(
checkpoint_url
,
progress
=
progress
))
def
mnasnet0_5
(
pretrained
=
False
,
**
kwargs
):
def
mnasnet0_5
(
pretrained
=
False
,
progress
=
True
,
**
kwargs
):
""" MNASNet with depth multiplier of 0.5. """
model
=
MNASNet
(
0.5
,
**
kwargs
)
if
pretrained
:
_load_pretrained
(
"mnasnet0_5"
,
model
)
_load_pretrained
(
"mnasnet0_5"
,
model
,
progress
)
return
model
def
mnasnet0_75
(
pretrained
=
False
,
**
kwargs
):
def
mnasnet0_75
(
pretrained
=
False
,
progress
=
True
,
**
kwargs
):
""" MNASNet with depth multiplier of 0.75. """
model
=
MNASNet
(
0.75
,
**
kwargs
)
if
pretrained
:
_load_pretrained
(
"mnasnet0_75"
,
model
)
_load_pretrained
(
"mnasnet0_75"
,
model
,
progress
)
return
model
def
mnasnet1_0
(
pretrained
=
False
,
**
kwargs
):
def
mnasnet1_0
(
pretrained
=
False
,
progress
=
True
,
**
kwargs
):
""" MNASNet with depth multiplier of 1.0. """
model
=
MNASNet
(
1.0
,
**
kwargs
)
if
pretrained
:
_load_pretrained
(
"mnasnet1_0"
,
model
)
_load_pretrained
(
"mnasnet1_0"
,
model
,
progress
)
return
model
def
mnasnet1_3
(
pretrained
=
False
,
**
kwargs
):
def
mnasnet1_3
(
pretrained
=
False
,
progress
=
True
,
**
kwargs
):
""" MNASNet with depth multiplier of 1.3. """
model
=
MNASNet
(
1.3
,
**
kwargs
)
if
pretrained
:
_load_pretrained
(
"mnasnet1_3"
,
model
)
_load_pretrained
(
"mnasnet1_3"
,
model
,
progress
)
return
model
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment