Unverified Commit d7f83550 authored by LXXXXR's avatar LXXXXR Committed by GitHub
Browse files

[Fix] Fix load_ckpt from pavi and s3 (#1020)

* fix load_ckpt

* revised according to comments

* revise according to comments

* fix typo
parent b36c4de1
......@@ -293,7 +293,8 @@ def load_from_http(filename, map_location=None, model_dir=None):
@CheckpointLoader.register_scheme(prefixes='pavi://')
def load_from_pavi(filename, map_location=None):
"""load checkpoint through the file path prefixed with pavi. In distributed
setting, this function only download checkpoint at local rank 0.
setting, this function download ckpt at all ranks to different temporary
directories.
Args:
filename (str): checkpoint file path with pavi prefix
......@@ -312,30 +313,20 @@ def load_from_pavi(filename, map_location=None):
except ImportError:
raise ImportError(
'Please install pavi to load checkpoint from modelcloud.')
rank, world_size = get_dist_info()
rank = int(os.environ.get('LOCAL_RANK', rank))
if rank == 0:
model = modelcloud.get(model_path)
with TemporaryDirectory() as tmp_dir:
downloaded_file = osp.join(tmp_dir, model.name)
model.download(downloaded_file)
checkpoint = torch.load(downloaded_file, map_location=map_location)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
model = modelcloud.get(model_path)
with TemporaryDirectory() as tmp_dir:
downloaded_file = osp.join(tmp_dir, model.name)
model.download(downloaded_file)
checkpoint = torch.load(
downloaded_file, map_location=map_location)
model = modelcloud.get(model_path)
with TemporaryDirectory() as tmp_dir:
downloaded_file = osp.join(tmp_dir, model.name)
model.download(downloaded_file)
checkpoint = torch.load(downloaded_file, map_location=map_location)
return checkpoint
@CheckpointLoader.register_scheme(prefixes='s3://')
def load_from_ceph(filename, map_location=None, backend='ceph'):
"""load checkpoint through the file path prefixed with s3. In distributed
setting, this function only download checkpoint at local rank 0.
"""load checkpoint through the file path prefixed with s3. In distributed
setting, this function download ckpt at all ranks to different temporary
directories.
Args:
filename (str): checkpoint file path with s3 prefix
......@@ -346,21 +337,14 @@ def load_from_ceph(filename, map_location=None, backend='ceph'):
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
rank, world_size = get_dist_info()
rank = int(os.environ.get('LOCAL_RANK', rank))
allowed_backends = ['ceph']
if backend not in allowed_backends:
raise ValueError(f'Load from Backend {backend} is not supported.')
if rank == 0:
fileclient = FileClient(backend=backend)
buffer = io.BytesIO(fileclient.get(filename))
checkpoint = torch.load(buffer, map_location=map_location)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
fileclient = FileClient(backend=backend)
buffer = io.BytesIO(fileclient.get(filename))
checkpoint = torch.load(buffer, map_location=map_location)
fileclient = FileClient(backend=backend)
buffer = io.BytesIO(fileclient.get(filename))
checkpoint = torch.load(buffer, map_location=map_location)
return checkpoint
......@@ -663,7 +647,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
if filename.startswith('pavi://'):
try:
from pavi import modelcloud
from pavi.exception import NodeNotFoundError
from pavi import exception
except ImportError:
raise ImportError(
'Please install pavi to load checkpoint from modelcloud.')
......@@ -672,7 +656,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
model_dir, model_name = osp.split(model_path)
try:
model = modelcloud.get(model_dir)
except NodeNotFoundError:
except exception.NodeNotFoundError:
model = root.create_training_model(model_dir)
with TemporaryDirectory() as tmp_dir:
checkpoint_file = osp.join(tmp_dir, model_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment