Unverified Commit c17ed460 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #483 from hellock/master

Update evaluation hooks
parents 64b1c8b6 a6ec45ba
import os
import os.path as osp
import shutil
import time
import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import Hook, obj_from_dict
from mmcv.parallel import scatter, collate
from pycocotools.cocoeval import COCOeval
......@@ -29,43 +28,14 @@ class DistEvalHook(Hook):
'dataset must be a Dataset object or a dict, not {}'.format(
type(dataset)))
self.interval = interval
self.lock_dir = None
def _barrier(self, rank, world_size):
    """Synchronize all processes through marker files in ``self.lock_dir``.

    This is a file-based substitute for ``torch.distributed.barrier()``,
    written because of issues encountered with the built-in barrier.

    Rank 0 waits until every other rank has written its ``<rank>.pkl``
    marker and then deletes all markers. Every other rank writes its own
    marker and waits until that marker has been deleted.

    Args:
        rank: Index of the calling process.
        world_size: Total number of participating processes.
    """
    if rank != 0:
        marker = osp.join(self.lock_dir, '{}.pkl'.format(rank))
        # An empty payload is enough: only the file's existence matters.
        mmcv.dump([], marker)
        # The marker disappears once rank 0 has seen every rank check in.
        while osp.exists(marker):
            time.sleep(1)
        return

    markers = [
        osp.join(self.lock_dir, '{}.pkl'.format(peer))
        for peer in range(1, world_size)
    ]
    # Phase 1: block until every peer has checked in.
    for marker in markers:
        while not osp.exists(marker):
            time.sleep(1)
    # Phase 2: release the peers by removing their markers.
    for marker in markers:
        os.remove(marker)
def before_run(self, runner):
    """Prepare a fresh lock directory under the runner's work directory.

    ``self.lock_dir`` is set on every rank. Only rank 0 touches the
    filesystem: it deletes any existing lock directory and then creates
    it again.

    Args:
        runner: The runner driving training; ``work_dir`` and ``rank``
            are read from it.
    """
    lock_dir = osp.join(runner.work_dir, '.lock_map_hook')
    self.lock_dir = lock_dir
    if runner.rank != 0:
        return
    # Start from an empty directory so stale markers cannot be seen.
    if osp.exists(lock_dir):
        shutil.rmtree(lock_dir)
    mmcv.mkdir_or_exist(lock_dir)
def after_run(self, runner):
    """Delete the lock directory when the run ends (rank 0 only).

    Args:
        runner: The runner driving training; only ``rank`` is read.
    """
    if runner.rank != 0:
        return
    shutil.rmtree(self.lock_dir)
def after_train_epoch(self, runner):
if not self.every_n_epochs(runner, self.interval):
return
runner.model.eval()
results = [None for _ in range(len(self.dataset))]
prog_bar = mmcv.ProgressBar(len(self.dataset))
if runner.rank == 0:
prog_bar = mmcv.ProgressBar(len(self.dataset))
for idx in range(runner.rank, len(self.dataset), runner.world_size):
data = self.dataset[idx]
data_gpu = scatter(
......@@ -79,12 +49,13 @@ class DistEvalHook(Hook):
results[idx] = result
batch_size = runner.world_size
for _ in range(batch_size):
prog_bar.update()
if runner.rank == 0:
for _ in range(batch_size):
prog_bar.update()
if runner.rank == 0:
print('\n')
self._barrier(runner.rank, runner.world_size)
dist.barrier()
for i in range(1, runner.world_size):
tmp_file = osp.join(runner.work_dir, 'temp_{}.pkl'.format(i))
tmp_results = mmcv.load(tmp_file)
......@@ -96,8 +67,8 @@ class DistEvalHook(Hook):
tmp_file = osp.join(runner.work_dir,
'temp_{}.pkl'.format(runner.rank))
mmcv.dump(results, tmp_file)
self._barrier(runner.rank, runner.world_size)
self._barrier(runner.rank, runner.world_size)
dist.barrier()
dist.barrier()
def evaluate(self):
    """Evaluation entry point; this version always raises.

    Presumably concrete hooks override this with the actual metric
    computation -- confirm against the subclasses.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
......@@ -179,7 +150,13 @@ class CocoDistEvalmAPHook(DistEvalHook):
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
field = '{}_mAP'.format(res_type)
runner.log_buffer.output[field] = cocoEval.stats[0]
metrics = ['mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l']
for i in range(len(metrics)):
key = '{}_{}'.format(res_type, metrics[i])
val = float('{:.3f}'.format(cocoEval.stats[i]))
runner.log_buffer.output[key] = val
runner.log_buffer.output['{}_mAP_copypaste'.format(res_type)] = (
'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
'{ap[4]:.3f} {ap[5]:.3f}').format(ap=cocoEval.stats[:6])
runner.log_buffer.ready = True
os.remove(tmp_file)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment