Unverified Commit 560719da authored by Yining Li's avatar Yining Li Committed by GitHub
Browse files

EvalHook uses case-insensitive key indicator matching and configurable test functions (#1076)

* EvalHook uses case-insensitive key indicator matching and configurable test functions

* * fix docstring

* * move test_fn import into __init__
* configurable greater/less keys

* * update unittest
* update DistEvalHook

* fix comments and remove debug code

* support single greater/less key
parent 49a1d347
......@@ -7,6 +7,7 @@ import torch.distributed as dist
from torch.nn.modules.batchnorm import _BatchNorm
from torch.utils.data import DataLoader
from mmcv.utils import is_seq_of
from .hook import Hook
......@@ -41,6 +42,16 @@ class EvalHook(Hook):
.etc will be inferred by 'greater' rule. Keys containing 'loss' will
be inferred by 'less' rule. Options are 'greater', 'less', None.
Default: None.
test_fn (callable, optional): test a model with samples from a
dataloader, and return the test results. If ``None``, the default
test function ``mmcv.engine.single_gpu_test`` will be used.
(default: ``None``)
greater_keys (List[str] | None, optional): Metric keys that will be
inferred by 'greater' comparison rule. If ``None``,
_default_greater_keys will be used. (default: ``None``)
less_keys (List[str] | None, optional): Metric keys that will be
inferred by 'less' comparison rule. If ``None``, _default_less_keys
will be used. (default: ``None``)
**eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset.
......@@ -55,11 +66,11 @@ class EvalHook(Hook):
rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
init_value_map = {'greater': -inf, 'less': inf}
greater_keys = [
_default_greater_keys = [
'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
'mAcc', 'aAcc'
]
less_keys = ['loss']
_default_less_keys = ['loss']
def __init__(self,
dataloader,
......@@ -68,6 +79,9 @@ class EvalHook(Hook):
by_epoch=True,
save_best=None,
rule=None,
test_fn=None,
greater_keys=None,
less_keys=None,
**eval_kwargs):
if not isinstance(dataloader, DataLoader):
raise TypeError(f'dataloader must be a pytorch DataLoader, '
......@@ -95,6 +109,28 @@ class EvalHook(Hook):
self.eval_kwargs = eval_kwargs
self.initial_flag = True
if test_fn is None:
from mmcv.engine import single_gpu_test
self.test_fn = single_gpu_test
else:
self.test_fn = test_fn
if greater_keys is None:
self.greater_keys = self._default_greater_keys
else:
if not isinstance(greater_keys, (list, tuple)):
greater_keys = (greater_keys, )
assert is_seq_of(greater_keys, str)
self.greater_keys = greater_keys
if less_keys is None:
self.less_keys = self._default_less_keys
else:
if not isinstance(less_keys, (list, tuple)):
less_keys = (less_keys, )
assert is_seq_of(less_keys, str)
self.less_keys = less_keys
if self.save_best is not None:
self.best_ckpt_path = None
self._init_rule(rule, self.save_best)
......@@ -103,7 +139,8 @@ class EvalHook(Hook):
"""Initialize rule, key_indicator, comparison_func, and best score.
Here is the rule to determine which rule is used for key indicator
when the rule is not specified:
when the rule is not specified (note that the key indicator matching
is case-insensitive):
1. If the key indicator is in ``self.greater_keys``, the rule will be
specified as 'greater'.
2. Or if the key indicator is in ``self.less_keys``, the rule will be
......@@ -124,13 +161,19 @@ class EvalHook(Hook):
if rule is None:
if key_indicator != 'auto':
if key_indicator in self.greater_keys:
# `_lc` here means we use the lower case of keys for
# case-insensitive matching
key_indicator_lc = key_indicator.lower()
greater_keys = [key.lower() for key in self.greater_keys]
less_keys = [key.lower() for key in self.less_keys]
if key_indicator_lc in greater_keys:
rule = 'greater'
elif key_indicator in self.less_keys:
elif key_indicator_lc in less_keys:
rule = 'less'
elif any(key in key_indicator for key in self.greater_keys):
elif any(key in key_indicator_lc for key in greater_keys):
rule = 'greater'
elif any(key in key_indicator for key in self.less_keys):
elif any(key in key_indicator_lc for key in less_keys):
rule = 'less'
else:
raise ValueError(f'Cannot infer the rule for key '
......@@ -181,8 +224,7 @@ class EvalHook(Hook):
if not self._should_evaluate(runner):
return
from mmcv.engine import single_gpu_test
results = single_gpu_test(runner.model, self.dataloader)
results = self.test_fn(runner.model, self.dataloader)
runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
key_score = self.evaluate(runner, results)
if self.save_best:
......@@ -311,6 +353,10 @@ class DistEvalHook(EvalHook):
.etc will be inferred by 'greater' rule. Keys containing 'loss' will
be inferred by 'less' rule. Options are 'greater', 'less', None.
Default: None.
test_fn (callable, optional): test a model with samples from a
dataloader in a multi-gpu manner, and return the test results. If
``None``, the default test function ``mmcv.engine.multi_gpu_test``
will be used. (default: ``None``)
tmpdir (str | None): Temporary directory to save the results of all
processes. Default: None.
gpu_collect (bool): Whether to use gpu or cpu to collect results.
......@@ -329,10 +375,18 @@ class DistEvalHook(EvalHook):
by_epoch=True,
save_best=None,
rule=None,
test_fn=None,
greater_keys=None,
less_keys=None,
broadcast_bn_buffer=True,
tmpdir=None,
gpu_collect=False,
**eval_kwargs):
if test_fn is None:
from mmcv.engine import multi_gpu_test
test_fn = multi_gpu_test
super().__init__(
dataloader,
start=start,
......@@ -340,7 +394,11 @@ class DistEvalHook(EvalHook):
by_epoch=by_epoch,
save_best=save_best,
rule=rule,
test_fn=test_fn,
greater_keys=greater_keys,
less_keys=less_keys,
**eval_kwargs)
self.broadcast_bn_buffer = broadcast_bn_buffer
self.tmpdir = tmpdir
self.gpu_collect = gpu_collect
......@@ -367,8 +425,7 @@ class DistEvalHook(EvalHook):
if tmpdir is None:
tmpdir = osp.join(runner.work_dir, '.eval_hook')
from mmcv.engine import multi_gpu_test
results = multi_gpu_test(
results = self.test_fn(
runner.model,
self.dataloader,
tmpdir=tmpdir,
......
......@@ -84,8 +84,8 @@ def _build_iter_runner():
class EvalHook(BaseEvalHook):
greater_keys = ['acc', 'top']
less_keys = ['loss', 'loss_top']
_default_greater_keys = ['acc', 'top']
_default_less_keys = ['loss', 'loss_top']
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
......@@ -273,6 +273,31 @@ def test_eval_hook():
assert runner.meta['hook_msgs']['best_score'] == 7
assert not osp.exists(old_ckpt_path)
# test EvalHook with custom test_fn and greater/less keys
loader = DataLoader(EvalDataset())
model = Model()
data_loader = DataLoader(EvalDataset())
eval_hook = EvalHook(
data_loader,
save_best='acc',
test_fn=mock.MagicMock(return_value={}),
greater_keys=[],
less_keys=['acc'])
with tempfile.TemporaryDirectory() as tmpdir:
logger = get_logger('test_eval')
runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
runner.register_checkpoint_hook(dict(interval=1))
runner.register_hook(eval_hook)
runner.run([loader], [('train', 1)], 8)
ckpt_path = osp.join(tmpdir, 'best_acc_epoch_6.pth')
assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path
assert osp.exists(ckpt_path)
assert runner.meta['hook_msgs']['best_score'] == -3
@patch('mmcv.engine.single_gpu_test', MagicMock)
@patch('mmcv.engine.multi_gpu_test', MagicMock)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment