Commit 123720b6 authored by lizz's avatar lizz Committed by Kai Chen
Browse files

Download one copy in distributed environment (#76)



* Download one copy in distributed environment
Signed-off-by: default avatarinnerlee <363664470@qq.com>

* fix
Signed-off-by: default avatarinnerlee <363664470@qq.com>

* format
Signed-off-by: default avatarinnerlee <363664470@qq.com>
parent 92a81b62
import os
import os.path as osp import os.path as osp
import pkgutil import pkgutil
import time import time
...@@ -8,6 +9,8 @@ import mmcv ...@@ -8,6 +9,8 @@ import mmcv
import torch import torch
from torch.utils import model_zoo from torch.utils import model_zoo
from .utils import get_dist_info
open_mmlab_model_urls = { open_mmlab_model_urls = {
'vgg16_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/vgg16_caffe-292e1171.pth', # noqa: E501 'vgg16_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/vgg16_caffe-292e1171.pth', # noqa: E501
...@@ -84,6 +87,20 @@ def load_state_dict(module, state_dict, strict=False, logger=None): ...@@ -84,6 +87,20 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
print(err_msg) print(err_msg)
def load_url_dist(url):
""" In distributed setting, this function only download checkpoint at
local rank 0 """
rank, world_size = get_dist_info()
rank = int(os.environ.get('LOCAL_RANK', rank))
if rank == 0:
checkpoint = model_zoo.load_url(url)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
checkpoint = model_zoo.load_url(url)
return checkpoint
def load_checkpoint(model, def load_checkpoint(model,
filename, filename,
map_location=None, map_location=None,
...@@ -114,12 +131,12 @@ def load_checkpoint(model, ...@@ -114,12 +131,12 @@ def load_checkpoint(model,
_urls = getattr(_zoo, 'model_urls') _urls = getattr(_zoo, 'model_urls')
model_urls.update(_urls) model_urls.update(_urls)
model_name = filename[11:] model_name = filename[11:]
checkpoint = model_zoo.load_url(model_urls[model_name]) checkpoint = load_url_dist(model_urls[model_name])
elif filename.startswith('open-mmlab://'): elif filename.startswith('open-mmlab://'):
model_name = filename[13:] model_name = filename[13:]
checkpoint = model_zoo.load_url(open_mmlab_model_urls[model_name]) checkpoint = load_url_dist(open_mmlab_model_urls[model_name])
elif filename.startswith(('http://', 'https://')): elif filename.startswith(('http://', 'https://')):
checkpoint = model_zoo.load_url(filename) checkpoint = load_url_dist(filename)
else: else:
if not osp.isfile(filename): if not osp.isfile(filename):
raise IOError('{} is not a checkpoint file'.format(filename)) raise IOError('{} is not a checkpoint file'.format(filename))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment