Unverified Commit 09661d97 authored by LXXXXR's avatar LXXXXR Committed by GitHub
Browse files

[Feature] support upload and download checkpoint from pavimodelcloud (#725)

* support upload and download checkpoint from pavimodelcloud

* try import pavi

* add unittest
parent 826d3a7b
...@@ -6,6 +6,7 @@ import time ...@@ -6,6 +6,7 @@ import time
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from importlib import import_module from importlib import import_module
from tempfile import TemporaryDirectory
import torch import torch
import torchvision import torchvision
...@@ -116,6 +117,34 @@ def load_url_dist(url, model_dir=None): ...@@ -116,6 +117,34 @@ def load_url_dist(url, model_dir=None):
return checkpoint return checkpoint
def load_pavimodel_dist(model_path, map_location=None):
"""In distributed setting, this function only download checkpoint at local
rank 0."""
try:
from pavi import modelcloud
except ImportError:
raise ImportError(
'Please install pavi to load checkpoint from modelcloud.')
rank, world_size = get_dist_info()
rank = int(os.environ.get('LOCAL_RANK', rank))
if rank == 0:
model = modelcloud.get(model_path)
with TemporaryDirectory() as tmp_dir:
downloaded_file = osp.join(tmp_dir, model.name)
model.download(downloaded_file)
checkpoint = torch.load(downloaded_file, map_location=map_location)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
model = modelcloud.get(model_path)
with TemporaryDirectory() as tmp_dir:
downloaded_file = osp.join(tmp_dir, model.name)
model.download(downloaded_file)
checkpoint = torch.load(
downloaded_file, map_location=map_location)
return checkpoint
def get_torchvision_models(): def get_torchvision_models():
model_urls = dict() model_urls = dict()
for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
...@@ -217,6 +246,9 @@ def _load_checkpoint(filename, map_location=None): ...@@ -217,6 +246,9 @@ def _load_checkpoint(filename, map_location=None):
checkpoint = _process_mmcls_checkpoint(checkpoint) checkpoint = _process_mmcls_checkpoint(checkpoint)
elif filename.startswith(('http://', 'https://')): elif filename.startswith(('http://', 'https://')):
checkpoint = load_url_dist(filename) checkpoint = load_url_dist(filename)
elif filename.startswith('pavi://'):
model_path = filename[7:]
checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
else: else:
if not osp.isfile(filename): if not osp.isfile(filename):
raise IOError(f'{filename} is not a checkpoint file') raise IOError(f'{filename} is not a checkpoint file')
...@@ -359,7 +391,6 @@ def save_checkpoint(model, filename, optimizer=None, meta=None): ...@@ -359,7 +391,6 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
raise TypeError(f'meta must be a dict or None, but got {type(meta)}') raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
mmcv.mkdir_or_exist(osp.dirname(filename))
if is_module_wrapper(model): if is_module_wrapper(model):
model = model.module model = model.module
...@@ -374,7 +405,30 @@ def save_checkpoint(model, filename, optimizer=None, meta=None): ...@@ -374,7 +405,30 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
checkpoint['optimizer'] = {} checkpoint['optimizer'] = {}
for name, optim in optimizer.items(): for name, optim in optimizer.items():
checkpoint['optimizer'][name] = optim.state_dict() checkpoint['optimizer'][name] = optim.state_dict()
# immediately flush buffer
with open(filename, 'wb') as f: if filename.startswith('pavi://'):
torch.save(checkpoint, f) try:
f.flush() from pavi import modelcloud
from pavi.exception import NodeNotFoundError
except ImportError:
raise ImportError(
'Please install pavi to load checkpoint from modelcloud.')
model_path = filename[7:]
root = modelcloud.Folder()
model_dir, model_name = osp.split(model_path)
try:
model = modelcloud.get(model_dir)
except NodeNotFoundError:
model = root.create_training_model(model_dir)
with TemporaryDirectory() as tmp_dir:
checkpoint_file = osp.join(tmp_dir, model_name)
with open(checkpoint_file, 'wb') as f:
torch.save(checkpoint, f)
f.flush()
model.create_file(checkpoint_file, name=model_name)
else:
mmcv.mkdir_or_exist(osp.dirname(filename))
# immediately flush buffer
with open(filename, 'wb') as f:
torch.save(checkpoint, f)
f.flush()
import sys
from collections import OrderedDict from collections import OrderedDict
from unittest.mock import MagicMock
import pytest
import torch.nn as nn import torch.nn as nn
from torch.nn.parallel import DataParallel from torch.nn.parallel import DataParallel
from mmcv.parallel.registry import MODULE_WRAPPERS from mmcv.parallel.registry import MODULE_WRAPPERS
from mmcv.runner.checkpoint import get_state_dict from mmcv.runner.checkpoint import get_state_dict, load_pavimodel_dist
@MODULE_WRAPPERS.register_module() @MODULE_WRAPPERS.register_module()
...@@ -30,6 +33,15 @@ class Model(nn.Module): ...@@ -30,6 +33,15 @@ class Model(nn.Module):
self.conv = nn.Conv2d(3, 3, 1) self.conv = nn.Conv2d(3, 3, 1)
class Mockpavimodel(object):
def __init__(self, name='fakename'):
self.name = name
def download(self, file):
pass
def assert_tensor_equal(tensor_a, tensor_b): def assert_tensor_equal(tensor_a, tensor_b):
assert tensor_a.eq(tensor_b).all() assert tensor_a.eq(tensor_b).all()
...@@ -110,3 +122,14 @@ def test_get_state_dict(): ...@@ -110,3 +122,14 @@ def test_get_state_dict():
wrapped_model.module.conv.module.weight) wrapped_model.module.conv.module.weight)
assert_tensor_equal(state_dict['conv.bias'], assert_tensor_equal(state_dict['conv.bias'],
wrapped_model.module.conv.module.bias) wrapped_model.module.conv.module.bias)
def test_load_pavimodel_dist():
sys.modules['pavi'] = MagicMock()
sys.modules['pavi.modelcloud'] = MagicMock()
pavimodel = Mockpavimodel()
import pavi
pavi.modelcloud.get = MagicMock(return_value=pavimodel)
with pytest.raises(FileNotFoundError):
# there is not such checkpoint for us to load
_ = load_pavimodel_dist('MyPaviFolder/checkpoint.pth')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment