Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
51a0c23f
Unverified
Commit
51a0c23f
authored
Jan 08, 2021
by
LXXXXR
Committed by
GitHub
Jan 08, 2021
Browse files
[Feature] Support load checkpoint from ceph (#778)
* support load checkpoint using ceph * minor change
parent
276883f1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
0 deletions
+26
-0
mmcv/runner/checkpoint.py
mmcv/runner/checkpoint.py
+26
-0
No files found.
mmcv/runner/checkpoint.py
View file @
51a0c23f
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
import
io
import
os
import
os
import
os.path
as
osp
import
os.path
as
osp
import
pkgutil
import
pkgutil
...
@@ -14,6 +15,7 @@ from torch.optim import Optimizer
...
@@ -14,6 +15,7 @@ from torch.optim import Optimizer
from
torch.utils
import
model_zoo
from
torch.utils
import
model_zoo
import
mmcv
import
mmcv
from
..fileio
import
FileClient
from
..fileio
import
load
as
load_file
from
..fileio
import
load
as
load_file
from
..parallel
import
is_module_wrapper
from
..parallel
import
is_module_wrapper
from
..utils
import
mkdir_or_exist
from
..utils
import
mkdir_or_exist
...
@@ -145,6 +147,27 @@ def load_pavimodel_dist(model_path, map_location=None):
...
@@ -145,6 +147,27 @@ def load_pavimodel_dist(model_path, map_location=None):
return
checkpoint
return
checkpoint
def load_fileclient_dist(filename, backend, map_location):
    """Load a checkpoint through a ``FileClient`` backend, staggered by rank.

    In a distributed setting the process whose local rank is 0 fetches the
    checkpoint first; the remaining processes wait on a barrier and fetch it
    afterwards. Outside a distributed setting (``world_size`` of 1) the
    checkpoint is fetched directly, whatever ``LOCAL_RANK`` says.

    Args:
        filename: Path of the checkpoint as understood by the backend.
        backend: Name of the ``FileClient`` backend. Only ``'ceph'`` is
            accepted.
        map_location: Forwarded unchanged to ``torch.load``.

    Returns:
        The object deserialized by ``torch.load``.

    Raises:
        ValueError: If ``backend`` is not one of the allowed backends.
    """
    # Validate the arguments first so a bad backend fails fast, before any
    # distributed state is queried.
    allowed_backends = ('ceph', )
    if backend not in allowed_backends:
        raise ValueError(f'Load from Backend {backend} is not supported.')

    rank, world_size = get_dist_info()
    # Prefer the launcher-provided local rank when it is available.
    rank = int(os.environ.get('LOCAL_RANK', rank))

    def _fetch():
        # Download the raw bytes and deserialize them from memory.
        # NOTE(security): torch.load unpickles the payload; only load
        # checkpoints from storage you trust.
        fileclient = FileClient(backend=backend)
        buffer = io.BytesIO(fileclient.get(filename))
        return torch.load(buffer, map_location=map_location)

    if world_size <= 1:
        # Single process: there is nobody to synchronize with. Loading here
        # unconditionally also avoids returning an unbound ``checkpoint``
        # when LOCAL_RANK is set to a non-zero value without an initialized
        # process group.
        return _fetch()

    if rank == 0:
        checkpoint = _fetch()
    # Every process must reach the barrier, so it stays outside the rank
    # checks.
    torch.distributed.barrier()
    if rank != 0:
        checkpoint = _fetch()
    return checkpoint
def
get_torchvision_models
():
def
get_torchvision_models
():
model_urls
=
dict
()
model_urls
=
dict
()
for
_
,
name
,
ispkg
in
pkgutil
.
walk_packages
(
torchvision
.
models
.
__path__
):
for
_
,
name
,
ispkg
in
pkgutil
.
walk_packages
(
torchvision
.
models
.
__path__
):
...
@@ -249,6 +272,9 @@ def _load_checkpoint(filename, map_location=None):
...
@@ -249,6 +272,9 @@ def _load_checkpoint(filename, map_location=None):
elif
filename
.
startswith
(
'pavi://'
):
elif
filename
.
startswith
(
'pavi://'
):
model_path
=
filename
[
7
:]
model_path
=
filename
[
7
:]
checkpoint
=
load_pavimodel_dist
(
model_path
,
map_location
=
map_location
)
checkpoint
=
load_pavimodel_dist
(
model_path
,
map_location
=
map_location
)
elif
filename
.
startswith
(
's3://'
):
checkpoint
=
load_fileclient_dist
(
filename
,
backend
=
'ceph'
,
map_location
=
map_location
)
else
:
else
:
if
not
osp
.
isfile
(
filename
):
if
not
osp
.
isfile
(
filename
):
raise
IOError
(
f
'
{
filename
}
is not a checkpoint file'
)
raise
IOError
(
f
'
{
filename
}
is not a checkpoint file'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment