Unverified Commit d7f83550 authored by LXXXXR's avatar LXXXXR Committed by GitHub
Browse files

[Fix] Fix load_ckpt from pavi and s3 (#1020)

* fix load_ckpt

* revised according to comments

* revise according to comments

* fix typo
parent b36c4de1
...@@ -293,7 +293,8 @@ def load_from_http(filename, map_location=None, model_dir=None): ...@@ -293,7 +293,8 @@ def load_from_http(filename, map_location=None, model_dir=None):
@CheckpointLoader.register_scheme(prefixes='pavi://') @CheckpointLoader.register_scheme(prefixes='pavi://')
def load_from_pavi(filename, map_location=None): def load_from_pavi(filename, map_location=None):
"""load checkpoint through the file path prefixed with pavi. In distributed """load checkpoint through the file path prefixed with pavi. In distributed
setting, this function only download checkpoint at local rank 0. setting, this function download ckpt at all ranks to different temporary
directories.
Args: Args:
filename (str): checkpoint file path with pavi prefix filename (str): checkpoint file path with pavi prefix
...@@ -312,30 +313,20 @@ def load_from_pavi(filename, map_location=None): ...@@ -312,30 +313,20 @@ def load_from_pavi(filename, map_location=None):
except ImportError: except ImportError:
raise ImportError( raise ImportError(
'Please install pavi to load checkpoint from modelcloud.') 'Please install pavi to load checkpoint from modelcloud.')
rank, world_size = get_dist_info()
rank = int(os.environ.get('LOCAL_RANK', rank)) model = modelcloud.get(model_path)
if rank == 0: with TemporaryDirectory() as tmp_dir:
model = modelcloud.get(model_path) downloaded_file = osp.join(tmp_dir, model.name)
with TemporaryDirectory() as tmp_dir: model.download(downloaded_file)
downloaded_file = osp.join(tmp_dir, model.name) checkpoint = torch.load(downloaded_file, map_location=map_location)
model.download(downloaded_file)
checkpoint = torch.load(downloaded_file, map_location=map_location)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
model = modelcloud.get(model_path)
with TemporaryDirectory() as tmp_dir:
downloaded_file = osp.join(tmp_dir, model.name)
model.download(downloaded_file)
checkpoint = torch.load(
downloaded_file, map_location=map_location)
return checkpoint return checkpoint
@CheckpointLoader.register_scheme(prefixes='s3://') @CheckpointLoader.register_scheme(prefixes='s3://')
def load_from_ceph(filename, map_location=None, backend='ceph'): def load_from_ceph(filename, map_location=None, backend='ceph'):
"""load checkpoint through the file path prefixed with s3. In distributed """load checkpoint through the file path prefixed with s3. In distributed
setting, this function only download checkpoint at local rank 0. setting, this function download ckpt at all ranks to different temporary
directories.
Args: Args:
filename (str): checkpoint file path with s3 prefix filename (str): checkpoint file path with s3 prefix
...@@ -346,21 +337,14 @@ def load_from_ceph(filename, map_location=None, backend='ceph'): ...@@ -346,21 +337,14 @@ def load_from_ceph(filename, map_location=None, backend='ceph'):
Returns: Returns:
dict or OrderedDict: The loaded checkpoint. dict or OrderedDict: The loaded checkpoint.
""" """
rank, world_size = get_dist_info()
rank = int(os.environ.get('LOCAL_RANK', rank))
allowed_backends = ['ceph'] allowed_backends = ['ceph']
if backend not in allowed_backends: if backend not in allowed_backends:
raise ValueError(f'Load from Backend {backend} is not supported.') raise ValueError(f'Load from Backend {backend} is not supported.')
if rank == 0:
fileclient = FileClient(backend=backend) fileclient = FileClient(backend=backend)
buffer = io.BytesIO(fileclient.get(filename)) buffer = io.BytesIO(fileclient.get(filename))
checkpoint = torch.load(buffer, map_location=map_location) checkpoint = torch.load(buffer, map_location=map_location)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
fileclient = FileClient(backend=backend)
buffer = io.BytesIO(fileclient.get(filename))
checkpoint = torch.load(buffer, map_location=map_location)
return checkpoint return checkpoint
...@@ -663,7 +647,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None): ...@@ -663,7 +647,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
if filename.startswith('pavi://'): if filename.startswith('pavi://'):
try: try:
from pavi import modelcloud from pavi import modelcloud
from pavi.exception import NodeNotFoundError from pavi import exception
except ImportError: except ImportError:
raise ImportError( raise ImportError(
'Please install pavi to load checkpoint from modelcloud.') 'Please install pavi to load checkpoint from modelcloud.')
...@@ -672,7 +656,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None): ...@@ -672,7 +656,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
model_dir, model_name = osp.split(model_path) model_dir, model_name = osp.split(model_path)
try: try:
model = modelcloud.get(model_dir) model = modelcloud.get(model_dir)
except NodeNotFoundError: except exception.NodeNotFoundError:
model = root.create_training_model(model_dir) model = root.create_training_model(model_dir)
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
checkpoint_file = osp.join(tmp_dir, model_name) checkpoint_file = osp.join(tmp_dir, model_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment