Unverified Commit 09661d97 authored by LXXXXR's avatar LXXXXR Committed by GitHub
Browse files

[Feature] support upload and download checkpoint from pavimodelcloud (#725)

* support upload and download checkpoint from pavimodelcloud

* try import pavi

* add unittest
parent 826d3a7b
......@@ -6,6 +6,7 @@ import time
import warnings
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory
import torch
import torchvision
......@@ -116,6 +117,34 @@ def load_url_dist(url, model_dir=None):
return checkpoint
def load_pavimodel_dist(model_path, map_location=None):
"""In distributed setting, this function only download checkpoint at local
rank 0."""
try:
from pavi import modelcloud
except ImportError:
raise ImportError(
'Please install pavi to load checkpoint from modelcloud.')
rank, world_size = get_dist_info()
rank = int(os.environ.get('LOCAL_RANK', rank))
if rank == 0:
model = modelcloud.get(model_path)
with TemporaryDirectory() as tmp_dir:
downloaded_file = osp.join(tmp_dir, model.name)
model.download(downloaded_file)
checkpoint = torch.load(downloaded_file, map_location=map_location)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
model = modelcloud.get(model_path)
with TemporaryDirectory() as tmp_dir:
downloaded_file = osp.join(tmp_dir, model.name)
model.download(downloaded_file)
checkpoint = torch.load(
downloaded_file, map_location=map_location)
return checkpoint
def get_torchvision_models():
model_urls = dict()
for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
......@@ -217,6 +246,9 @@ def _load_checkpoint(filename, map_location=None):
checkpoint = _process_mmcls_checkpoint(checkpoint)
elif filename.startswith(('http://', 'https://')):
checkpoint = load_url_dist(filename)
elif filename.startswith('pavi://'):
model_path = filename[7:]
checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
else:
if not osp.isfile(filename):
raise IOError(f'{filename} is not a checkpoint file')
......@@ -359,7 +391,6 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
mmcv.mkdir_or_exist(osp.dirname(filename))
if is_module_wrapper(model):
model = model.module
......@@ -374,7 +405,30 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
checkpoint['optimizer'] = {}
for name, optim in optimizer.items():
checkpoint['optimizer'][name] = optim.state_dict()
# immediately flush buffer
with open(filename, 'wb') as f:
torch.save(checkpoint, f)
f.flush()
if filename.startswith('pavi://'):
try:
from pavi import modelcloud
from pavi.exception import NodeNotFoundError
except ImportError:
raise ImportError(
'Please install pavi to load checkpoint from modelcloud.')
model_path = filename[7:]
root = modelcloud.Folder()
model_dir, model_name = osp.split(model_path)
try:
model = modelcloud.get(model_dir)
except NodeNotFoundError:
model = root.create_training_model(model_dir)
with TemporaryDirectory() as tmp_dir:
checkpoint_file = osp.join(tmp_dir, model_name)
with open(checkpoint_file, 'wb') as f:
torch.save(checkpoint, f)
f.flush()
model.create_file(checkpoint_file, name=model_name)
else:
mmcv.mkdir_or_exist(osp.dirname(filename))
# immediately flush buffer
with open(filename, 'wb') as f:
torch.save(checkpoint, f)
f.flush()
import sys
from collections import OrderedDict
from unittest.mock import MagicMock
import pytest
import torch.nn as nn
from torch.nn.parallel import DataParallel
from mmcv.parallel.registry import MODULE_WRAPPERS
from mmcv.runner.checkpoint import get_state_dict
from mmcv.runner.checkpoint import get_state_dict, load_pavimodel_dist
@MODULE_WRAPPERS.register_module()
......@@ -30,6 +33,15 @@ class Model(nn.Module):
self.conv = nn.Conv2d(3, 3, 1)
class Mockpavimodel(object):
def __init__(self, name='fakename'):
self.name = name
def download(self, file):
pass
def assert_tensor_equal(tensor_a, tensor_b):
assert tensor_a.eq(tensor_b).all()
......@@ -110,3 +122,14 @@ def test_get_state_dict():
wrapped_model.module.conv.module.weight)
assert_tensor_equal(state_dict['conv.bias'],
wrapped_model.module.conv.module.bias)
def test_load_pavimodel_dist():
sys.modules['pavi'] = MagicMock()
sys.modules['pavi.modelcloud'] = MagicMock()
pavimodel = Mockpavimodel()
import pavi
pavi.modelcloud.get = MagicMock(return_value=pavimodel)
with pytest.raises(FileNotFoundError):
# there is not such checkpoint for us to load
_ = load_pavimodel_dist('MyPaviFolder/checkpoint.pth')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment