Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
1c2e665a
Unverified
Commit
1c2e665a
authored
Feb 10, 2021
by
lizz
Committed by
GitHub
Feb 10, 2021
Browse files
map_location for all (#826)
* map_location for all * format * hmm * map_location * back * doc * same
parent
999f2d08
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
11 deletions
+13
-11
mmcv/runner/checkpoint.py
mmcv/runner/checkpoint.py
+11
-9
tests/test_load_model_zoo.py
tests/test_load_model_zoo.py
+2
-2
No files found.
mmcv/runner/checkpoint.py
View file @
1c2e665a
...
...
@@ -269,22 +269,23 @@ def load_from_http(filename, map_location=None, model_dir=None):
Args:
filename (str): checkpoint file path with modelzoo or
torchvision prefix
map_location (str, optional):
it's not use
.
map_location (str, optional):
Same as :func:`torch.load`
.
model_dir (string, optional): directory in which to save the object,
Default: None
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
rank
,
world_size
=
get_dist_info
()
rank
=
int
(
os
.
environ
.
get
(
'LOCAL_RANK'
,
rank
))
if
rank
==
0
:
checkpoint
=
model_zoo
.
load_url
(
filename
,
model_dir
=
model_dir
)
checkpoint
=
model_zoo
.
load_url
(
filename
,
model_dir
=
model_dir
,
map_location
=
map_location
)
if
world_size
>
1
:
torch
.
distributed
.
barrier
()
if
rank
>
0
:
checkpoint
=
model_zoo
.
load_url
(
filename
,
model_dir
=
model_dir
)
checkpoint
=
model_zoo
.
load_url
(
filename
,
model_dir
=
model_dir
,
map_location
=
map_location
)
return
checkpoint
...
...
@@ -370,7 +371,7 @@ def load_from_torchvision(filename, map_location=None):
Args:
filename (str): checkpoint file path with modelzoo or
torchvision prefix
map_location (str, optional):
it's not use
.
map_location (str, optional):
Same as :func:`torch.load`
.
Returns:
dict or OrderedDict: The loaded checkpoint.
...
...
@@ -382,7 +383,7 @@ def load_from_torchvision(filename, map_location=None):
model_name
=
filename
[
11
:]
else
:
model_name
=
filename
[
14
:]
return
load_from_http
(
model_urls
[
model_name
])
return
load_from_http
(
model_urls
[
model_name
]
,
map_location
=
map_location
)
@
CheckpointLoader
.
register_scheme
(
prefixes
=
(
'open-mmlab://'
,
'openmmlab://'
))
...
...
@@ -416,7 +417,7 @@ def load_from_openmmlab(filename, map_location=None):
model_url
=
model_urls
[
model_name
]
# check if is url
if
model_url
.
startswith
((
'http://'
,
'https://'
)):
checkpoint
=
load_from_http
(
model_url
)
checkpoint
=
load_from_http
(
model_url
,
map_location
=
map_location
)
else
:
filename
=
osp
.
join
(
_get_mmcv_home
(),
model_url
)
if
not
osp
.
isfile
(
filename
):
...
...
@@ -431,7 +432,7 @@ def load_from_mmcls(filename, map_location=None):
Args:
filename (str): checkpoint file path with mmcls prefix
map_location (str, optional):
it's not use
.
map_location (str, optional):
Same as :func:`torch.load`
.
Returns:
dict or OrderedDict: The loaded checkpoint.
...
...
@@ -439,7 +440,8 @@ def load_from_mmcls(filename, map_location=None):
model_urls
=
get_mmcls_models
()
model_name
=
filename
[
8
:]
checkpoint
=
load_from_http
(
model_urls
[
model_name
])
checkpoint
=
load_from_http
(
model_urls
[
model_name
],
map_location
=
map_location
)
checkpoint
=
_process_mmcls_checkpoint
(
checkpoint
)
return
checkpoint
...
...
tests/test_load_model_zoo.py
View file @
1c2e665a
...
...
@@ -58,11 +58,11 @@ def test_get_deprecated_models():
}
def
load_from_http
(
url
):
def
load_from_http
(
url
,
map_location
=
None
):
return
'url:'
+
url
def
load_url
(
url
,
model_dir
=
None
):
def
load_url
(
url
,
map_location
=
None
,
model_dir
=
None
):
return
load_from_http
(
url
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment