Commit fd67644c authored by Kai Chen's avatar Kai Chen
Browse files

use torch.distributed.barrier() instead of the self-implemented one

parent 64b1c8b6
import os import os
import os.path as osp import os.path as osp
import shutil
import time
import mmcv import mmcv
import numpy as np import numpy as np
import torch import torch
import torch.distributed as dist
from mmcv.runner import Hook, obj_from_dict from mmcv.runner import Hook, obj_from_dict
from mmcv.parallel import scatter, collate from mmcv.parallel import scatter, collate
from pycocotools.cocoeval import COCOeval from pycocotools.cocoeval import COCOeval
...@@ -29,36 +28,6 @@ class DistEvalHook(Hook): ...@@ -29,36 +28,6 @@ class DistEvalHook(Hook):
'dataset must be a Dataset object or a dict, not {}'.format( 'dataset must be a Dataset object or a dict, not {}'.format(
type(dataset))) type(dataset)))
self.interval = interval self.interval = interval
self.lock_dir = None
def _barrier(self, rank, world_size):
"""Due to some issues with `torch.distributed.barrier()`, we have to
implement this ugly barrier function.
"""
if rank == 0:
for i in range(1, world_size):
tmp = osp.join(self.lock_dir, '{}.pkl'.format(i))
while not (osp.exists(tmp)):
time.sleep(1)
for i in range(1, world_size):
tmp = osp.join(self.lock_dir, '{}.pkl'.format(i))
os.remove(tmp)
else:
tmp = osp.join(self.lock_dir, '{}.pkl'.format(rank))
mmcv.dump([], tmp)
while osp.exists(tmp):
time.sleep(1)
def before_run(self, runner):
self.lock_dir = osp.join(runner.work_dir, '.lock_map_hook')
if runner.rank == 0:
if osp.exists(self.lock_dir):
shutil.rmtree(self.lock_dir)
mmcv.mkdir_or_exist(self.lock_dir)
def after_run(self, runner):
if runner.rank == 0:
shutil.rmtree(self.lock_dir)
def after_train_epoch(self, runner): def after_train_epoch(self, runner):
if not self.every_n_epochs(runner, self.interval): if not self.every_n_epochs(runner, self.interval):
...@@ -84,7 +53,7 @@ class DistEvalHook(Hook): ...@@ -84,7 +53,7 @@ class DistEvalHook(Hook):
if runner.rank == 0: if runner.rank == 0:
print('\n') print('\n')
self._barrier(runner.rank, runner.world_size) dist.barrier()
for i in range(1, runner.world_size): for i in range(1, runner.world_size):
tmp_file = osp.join(runner.work_dir, 'temp_{}.pkl'.format(i)) tmp_file = osp.join(runner.work_dir, 'temp_{}.pkl'.format(i))
tmp_results = mmcv.load(tmp_file) tmp_results = mmcv.load(tmp_file)
...@@ -96,8 +65,8 @@ class DistEvalHook(Hook): ...@@ -96,8 +65,8 @@ class DistEvalHook(Hook):
tmp_file = osp.join(runner.work_dir, tmp_file = osp.join(runner.work_dir,
'temp_{}.pkl'.format(runner.rank)) 'temp_{}.pkl'.format(runner.rank))
mmcv.dump(results, tmp_file) mmcv.dump(results, tmp_file)
self._barrier(runner.rank, runner.world_size) dist.barrier()
self._barrier(runner.rank, runner.world_size) dist.barrier()
def evaluate(self): def evaluate(self):
raise NotImplementedError raise NotImplementedError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment