Unverified Commit 51a0c23f authored by LXXXXR's avatar LXXXXR Committed by GitHub
Browse files

[Feature] Support load checkpoint from ceph (#778)

* support load checkpoint using ceph

* minor change
parent 276883f1
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
import io
import os import os
import os.path as osp import os.path as osp
import pkgutil import pkgutil
...@@ -14,6 +15,7 @@ from torch.optim import Optimizer ...@@ -14,6 +15,7 @@ from torch.optim import Optimizer
from torch.utils import model_zoo from torch.utils import model_zoo
import mmcv import mmcv
from ..fileio import FileClient
from ..fileio import load as load_file from ..fileio import load as load_file
from ..parallel import is_module_wrapper from ..parallel import is_module_wrapper
from ..utils import mkdir_or_exist from ..utils import mkdir_or_exist
...@@ -145,6 +147,27 @@ def load_pavimodel_dist(model_path, map_location=None): ...@@ -145,6 +147,27 @@ def load_pavimodel_dist(model_path, map_location=None):
return checkpoint return checkpoint
def load_fileclient_dist(filename, backend, map_location):
    """Load a checkpoint through a ``FileClient`` backend, staggered by rank.

    The process whose local rank is 0 fetches and deserializes the checkpoint
    first. When more than one process takes part, all processes then meet at
    a barrier, after which every remaining rank fetches and deserializes its
    own copy.

    Args:
        filename (str): Path handed to ``FileClient.get``.
        backend (str): Name of the file client backend. Only ``'ceph'`` is
            accepted.
        map_location: Forwarded unchanged to ``torch.load``.

    Returns:
        The object produced by ``torch.load`` for the fetched bytes.

    Raises:
        ValueError: If ``backend`` is not one of the accepted backends.
    """

    def _fetch_and_deserialize():
        # Pull the raw bytes through the file client and let torch.load
        # read them from an in-memory stream.
        client = FileClient(backend=backend)
        stream = io.BytesIO(client.get(filename))
        return torch.load(stream, map_location=map_location)

    global_rank, world_size = get_dist_info()
    # LOCAL_RANK, when present in the environment, overrides the rank
    # reported by get_dist_info().
    local_rank = int(os.environ.get('LOCAL_RANK', global_rank))

    supported = ['ceph']
    if backend not in supported:
        raise ValueError(f'Load from Backend {backend} is not supported.')

    is_first = local_rank == 0
    if is_first:
        checkpoint = _fetch_and_deserialize()

    # Every rank waits here until the first local rank has finished loading.
    if world_size > 1:
        torch.distributed.barrier()

    if local_rank > 0:
        checkpoint = _fetch_and_deserialize()

    return checkpoint
def get_torchvision_models(): def get_torchvision_models():
model_urls = dict() model_urls = dict()
for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
...@@ -249,6 +272,9 @@ def _load_checkpoint(filename, map_location=None): ...@@ -249,6 +272,9 @@ def _load_checkpoint(filename, map_location=None):
elif filename.startswith('pavi://'): elif filename.startswith('pavi://'):
model_path = filename[7:] model_path = filename[7:]
checkpoint = load_pavimodel_dist(model_path, map_location=map_location) checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
elif filename.startswith('s3://'):
checkpoint = load_fileclient_dist(
filename, backend='ceph', map_location=map_location)
else: else:
if not osp.isfile(filename): if not osp.isfile(filename):
raise IOError(f'{filename} is not a checkpoint file') raise IOError(f'{filename} is not a checkpoint file')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment