Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
3e1e297d
Unverified
Commit
3e1e297d
authored
Jul 13, 2019
by
Kai Chen
Committed by
GitHub
Jul 13, 2019
Browse files
use torchvision:// instead of modelzoo:// (#89)
parent
e5ca8846
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
10 deletions
+22
-10
mmcv/runner/checkpoint.py
mmcv/runner/checkpoint.py
+22
-10
No files found.
mmcv/runner/checkpoint.py
View file @
3e1e297d
...
...
@@ -2,11 +2,13 @@ import os
import
os.path
as
osp
import
pkgutil
import
time
import
warnings
from
collections
import
OrderedDict
from
importlib
import
import_module
import
mmcv
import
torch
import
torchvision
from
torch.utils
import
model_zoo
from
.utils
import
get_dist_info
...
...
@@ -85,7 +87,7 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
if
strict
:
raise
RuntimeError
(
err_msg
)
elif
logger
is
not
None
:
logger
.
warn
(
err_msg
)
logger
.
warn
ing
(
err_msg
)
else
:
print
(
err_msg
)
...
...
@@ -104,6 +106,18 @@ def load_url_dist(url):
return
checkpoint
def
get_torchvision_models
():
model_urls
=
dict
()
for
_
,
name
,
ispkg
in
pkgutil
.
walk_packages
(
torchvision
.
models
.
__path__
):
if
ispkg
:
continue
_zoo
=
import_module
(
'torchvision.models.{}'
.
format
(
name
))
if
hasattr
(
_zoo
,
'model_urls'
):
_urls
=
getattr
(
_zoo
,
'model_urls'
)
model_urls
.
update
(
_urls
)
return
model_urls
def
load_checkpoint
(
model
,
filename
,
map_location
=
None
,
...
...
@@ -124,17 +138,15 @@ def load_checkpoint(model,
"""
# load checkpoint from modelzoo or file or url
if
filename
.
startswith
(
'modelzoo://'
):
import
torchvision
model_urls
=
dict
()
for
_
,
name
,
ispkg
in
pkgutil
.
walk_packages
(
torchvision
.
models
.
__path__
):
if
not
ispkg
:
_zoo
=
import_module
(
'torchvision.models.{}'
.
format
(
name
))
if
hasattr
(
_zoo
,
'model_urls'
):
_urls
=
getattr
(
_zoo
,
'model_urls'
)
model_urls
.
update
(
_urls
)
warnings
.
warn
(
'The URL scheme of "modelzoo://" is deprecated, please '
'use "torchvision://" instead'
)
model_urls
=
get_torchvision_models
()
model_name
=
filename
[
11
:]
checkpoint
=
load_url_dist
(
model_urls
[
model_name
])
elif
filename
.
startswith
(
'torchvision://'
):
model_urls
=
get_torchvision_models
()
model_name
=
filename
[
14
:]
checkpoint
=
load_url_dist
(
model_urls
[
model_name
])
elif
filename
.
startswith
(
'open-mmlab://'
):
model_name
=
filename
[
13
:]
checkpoint
=
load_url_dist
(
open_mmlab_model_urls
[
model_name
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment