Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
3e1e297d
Unverified
Commit
3e1e297d
authored
Jul 13, 2019
by
Kai Chen
Committed by
GitHub
Jul 13, 2019
Browse files
use torchvision:// instead of modelzoo:// (#89)
parent
e5ca8846
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
10 deletions
+22
-10
mmcv/runner/checkpoint.py
mmcv/runner/checkpoint.py
+22
-10
No files found.
mmcv/runner/checkpoint.py
View file @
3e1e297d
...
@@ -2,11 +2,13 @@ import os
...
@@ -2,11 +2,13 @@ import os
import
os.path
as
osp
import
os.path
as
osp
import
pkgutil
import
pkgutil
import
time
import
time
import
warnings
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
importlib
import
import_module
from
importlib
import
import_module
import
mmcv
import
mmcv
import
torch
import
torch
import
torchvision
from
torch.utils
import
model_zoo
from
torch.utils
import
model_zoo
from
.utils
import
get_dist_info
from
.utils
import
get_dist_info
...
@@ -85,7 +87,7 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
...
@@ -85,7 +87,7 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
if
strict
:
if
strict
:
raise
RuntimeError
(
err_msg
)
raise
RuntimeError
(
err_msg
)
elif
logger
is
not
None
:
elif
logger
is
not
None
:
logger
.
warn
(
err_msg
)
logger
.
warn
ing
(
err_msg
)
else
:
else
:
print
(
err_msg
)
print
(
err_msg
)
...
@@ -104,6 +106,18 @@ def load_url_dist(url):
...
@@ -104,6 +106,18 @@ def load_url_dist(url):
return
checkpoint
return
checkpoint
def
get_torchvision_models
():
model_urls
=
dict
()
for
_
,
name
,
ispkg
in
pkgutil
.
walk_packages
(
torchvision
.
models
.
__path__
):
if
ispkg
:
continue
_zoo
=
import_module
(
'torchvision.models.{}'
.
format
(
name
))
if
hasattr
(
_zoo
,
'model_urls'
):
_urls
=
getattr
(
_zoo
,
'model_urls'
)
model_urls
.
update
(
_urls
)
return
model_urls
def
load_checkpoint
(
model
,
def
load_checkpoint
(
model
,
filename
,
filename
,
map_location
=
None
,
map_location
=
None
,
...
@@ -124,17 +138,15 @@ def load_checkpoint(model,
...
@@ -124,17 +138,15 @@ def load_checkpoint(model,
"""
"""
# load checkpoint from modelzoo or file or url
# load checkpoint from modelzoo or file or url
if
filename
.
startswith
(
'modelzoo://'
):
if
filename
.
startswith
(
'modelzoo://'
):
import
torchvision
warnings
.
warn
(
'The URL scheme of "modelzoo://" is deprecated, please '
model_urls
=
dict
()
'use "torchvision://" instead'
)
for
_
,
name
,
ispkg
in
pkgutil
.
walk_packages
(
model_urls
=
get_torchvision_models
()
torchvision
.
models
.
__path__
):
if
not
ispkg
:
_zoo
=
import_module
(
'torchvision.models.{}'
.
format
(
name
))
if
hasattr
(
_zoo
,
'model_urls'
):
_urls
=
getattr
(
_zoo
,
'model_urls'
)
model_urls
.
update
(
_urls
)
model_name
=
filename
[
11
:]
model_name
=
filename
[
11
:]
checkpoint
=
load_url_dist
(
model_urls
[
model_name
])
checkpoint
=
load_url_dist
(
model_urls
[
model_name
])
elif
filename
.
startswith
(
'torchvision://'
):
model_urls
=
get_torchvision_models
()
model_name
=
filename
[
14
:]
checkpoint
=
load_url_dist
(
model_urls
[
model_name
])
elif
filename
.
startswith
(
'open-mmlab://'
):
elif
filename
.
startswith
(
'open-mmlab://'
):
model_name
=
filename
[
13
:]
model_name
=
filename
[
13
:]
checkpoint
=
load_url_dist
(
open_mmlab_model_urls
[
model_name
])
checkpoint
=
load_url_dist
(
open_mmlab_model_urls
[
model_name
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment