Unverified commit c17ed460, authored by Kai Chen, committed by GitHub
Browse files

Merge pull request #483 from hellock/master

Update evaluation hooks
parents 64b1c8b6 a6ec45ba
import os import os
import os.path as osp import os.path as osp
import shutil
import time
import mmcv import mmcv
import numpy as np import numpy as np
import torch import torch
import torch.distributed as dist
from mmcv.runner import Hook, obj_from_dict from mmcv.runner import Hook, obj_from_dict
from mmcv.parallel import scatter, collate from mmcv.parallel import scatter, collate
from pycocotools.cocoeval import COCOeval from pycocotools.cocoeval import COCOeval
...@@ -29,42 +28,13 @@ class DistEvalHook(Hook): ...@@ -29,42 +28,13 @@ class DistEvalHook(Hook):
'dataset must be a Dataset object or a dict, not {}'.format( 'dataset must be a Dataset object or a dict, not {}'.format(
type(dataset))) type(dataset)))
self.interval = interval self.interval = interval
self.lock_dir = None
def _barrier(self, rank, world_size):
"""Due to some issues with `torch.distributed.barrier()`, we have to
implement this ugly barrier function.
"""
if rank == 0:
for i in range(1, world_size):
tmp = osp.join(self.lock_dir, '{}.pkl'.format(i))
while not (osp.exists(tmp)):
time.sleep(1)
for i in range(1, world_size):
tmp = osp.join(self.lock_dir, '{}.pkl'.format(i))
os.remove(tmp)
else:
tmp = osp.join(self.lock_dir, '{}.pkl'.format(rank))
mmcv.dump([], tmp)
while osp.exists(tmp):
time.sleep(1)
def before_run(self, runner):
    """Record the lock directory path and let rank 0 (re)create it.

    ``self.lock_dir`` is set to ``<runner.work_dir>/.lock_map_hook`` on
    every rank. Only rank 0 touches the filesystem: an existing directory
    is removed and a new one is created in its place.

    Args:
        runner: Object providing ``work_dir`` and ``rank``.
    """
    lock_dir = osp.join(runner.work_dir, '.lock_map_hook')
    self.lock_dir = lock_dir
    if runner.rank != 0:
        return
    # Recreate from scratch -- presumably to drop markers left behind by
    # an earlier run (TODO confirm).
    if osp.exists(lock_dir):
        shutil.rmtree(lock_dir)
    mmcv.mkdir_or_exist(lock_dir)
def after_run(self, runner):
    """Remove the lock directory; only rank 0 performs the deletion.

    Args:
        runner: Object providing ``rank``.
    """
    if runner.rank != 0:
        return
    shutil.rmtree(self.lock_dir)
def after_train_epoch(self, runner): def after_train_epoch(self, runner):
if not self.every_n_epochs(runner, self.interval): if not self.every_n_epochs(runner, self.interval):
return return
runner.model.eval() runner.model.eval()
results = [None for _ in range(len(self.dataset))] results = [None for _ in range(len(self.dataset))]
if runner.rank == 0:
prog_bar = mmcv.ProgressBar(len(self.dataset)) prog_bar = mmcv.ProgressBar(len(self.dataset))
for idx in range(runner.rank, len(self.dataset), runner.world_size): for idx in range(runner.rank, len(self.dataset), runner.world_size):
data = self.dataset[idx] data = self.dataset[idx]
...@@ -79,12 +49,13 @@ class DistEvalHook(Hook): ...@@ -79,12 +49,13 @@ class DistEvalHook(Hook):
results[idx] = result results[idx] = result
batch_size = runner.world_size batch_size = runner.world_size
if runner.rank == 0:
for _ in range(batch_size): for _ in range(batch_size):
prog_bar.update() prog_bar.update()
if runner.rank == 0: if runner.rank == 0:
print('\n') print('\n')
self._barrier(runner.rank, runner.world_size) dist.barrier()
for i in range(1, runner.world_size): for i in range(1, runner.world_size):
tmp_file = osp.join(runner.work_dir, 'temp_{}.pkl'.format(i)) tmp_file = osp.join(runner.work_dir, 'temp_{}.pkl'.format(i))
tmp_results = mmcv.load(tmp_file) tmp_results = mmcv.load(tmp_file)
...@@ -96,8 +67,8 @@ class DistEvalHook(Hook): ...@@ -96,8 +67,8 @@ class DistEvalHook(Hook):
tmp_file = osp.join(runner.work_dir, tmp_file = osp.join(runner.work_dir,
'temp_{}.pkl'.format(runner.rank)) 'temp_{}.pkl'.format(runner.rank))
mmcv.dump(results, tmp_file) mmcv.dump(results, tmp_file)
self._barrier(runner.rank, runner.world_size) dist.barrier()
self._barrier(runner.rank, runner.world_size) dist.barrier()
def evaluate(self):
    """Evaluation step; this base implementation is not provided.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
...@@ -179,7 +150,13 @@ class CocoDistEvalmAPHook(DistEvalHook): ...@@ -179,7 +150,13 @@ class CocoDistEvalmAPHook(DistEvalHook):
cocoEval.evaluate() cocoEval.evaluate()
cocoEval.accumulate() cocoEval.accumulate()
cocoEval.summarize() cocoEval.summarize()
field = '{}_mAP'.format(res_type) metrics = ['mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l']
runner.log_buffer.output[field] = cocoEval.stats[0] for i in range(len(metrics)):
key = '{}_{}'.format(res_type, metrics[i])
val = float('{:.3f}'.format(cocoEval.stats[i]))
runner.log_buffer.output[key] = val
runner.log_buffer.output['{}_mAP_copypaste'.format(res_type)] = (
'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
'{ap[4]:.3f} {ap[5]:.3f}').format(ap=cocoEval.stats[:6])
runner.log_buffer.ready = True runner.log_buffer.ready = True
os.remove(tmp_file) os.remove(tmp_file)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment