Unverified Commit 560719da authored by Yining Li's avatar Yining Li Committed by GitHub
Browse files

EvalHook uses case-insensitive key indicator matching and configurabl… (#1076)

* EvalHook uses case-insensitive key indicator matching and configurable test functions

* * fix docstring

* * move test_fn import into __init__
* configurable greater/less keys

* * update unittest
* update DistEvalHook

* fix comments and remove debug code

* support single greater/less key
parent 49a1d347
...@@ -7,6 +7,7 @@ import torch.distributed as dist ...@@ -7,6 +7,7 @@ import torch.distributed as dist
from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.batchnorm import _BatchNorm
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from mmcv.utils import is_seq_of
from .hook import Hook from .hook import Hook
...@@ -41,6 +42,16 @@ class EvalHook(Hook): ...@@ -41,6 +42,16 @@ class EvalHook(Hook):
.etc will be inferred by 'greater' rule. Keys contain 'loss' will .etc will be inferred by 'greater' rule. Keys contain 'loss' will
be inferred by 'less' rule. Options are 'greater', 'less', None. be inferred by 'less' rule. Options are 'greater', 'less', None.
Default: None. Default: None.
test_fn (callable, optional): test a model with samples from a
dataloader, and return the test results. If ``None``, the default
test function ``mmcv.engine.single_gpu_test`` will be used.
(default: ``None``)
greater_keys (List[str] | None, optional): Metric keys that will be
inferred by 'greater' comparison rule rule. If ``None``,
_default_greater_keys will be used. (default: ``None``)
less_keys (List[str] | None, optional): Metric keys that will be
inferred by 'less' comparison rule. If ``None``, _default_less_keys
will be used. (default: ``None``)
**eval_kwargs: Evaluation arguments fed into the evaluate function of **eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset. the dataset.
...@@ -55,11 +66,11 @@ class EvalHook(Hook): ...@@ -55,11 +66,11 @@ class EvalHook(Hook):
rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y} rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
init_value_map = {'greater': -inf, 'less': inf} init_value_map = {'greater': -inf, 'less': inf}
greater_keys = [ _default_greater_keys = [
'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU', 'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
'mAcc', 'aAcc' 'mAcc', 'aAcc'
] ]
less_keys = ['loss'] _default_less_keys = ['loss']
def __init__(self, def __init__(self,
dataloader, dataloader,
...@@ -68,6 +79,9 @@ class EvalHook(Hook): ...@@ -68,6 +79,9 @@ class EvalHook(Hook):
by_epoch=True, by_epoch=True,
save_best=None, save_best=None,
rule=None, rule=None,
test_fn=None,
greater_keys=None,
less_keys=None,
**eval_kwargs): **eval_kwargs):
if not isinstance(dataloader, DataLoader): if not isinstance(dataloader, DataLoader):
raise TypeError(f'dataloader must be a pytorch DataLoader, ' raise TypeError(f'dataloader must be a pytorch DataLoader, '
...@@ -95,6 +109,28 @@ class EvalHook(Hook): ...@@ -95,6 +109,28 @@ class EvalHook(Hook):
self.eval_kwargs = eval_kwargs self.eval_kwargs = eval_kwargs
self.initial_flag = True self.initial_flag = True
if test_fn is None:
from mmcv.engine import single_gpu_test
self.test_fn = single_gpu_test
else:
self.test_fn = test_fn
if greater_keys is None:
self.greater_keys = self._default_greater_keys
else:
if not isinstance(greater_keys, (list, tuple)):
greater_keys = (greater_keys, )
assert is_seq_of(greater_keys, str)
self.greater_keys = greater_keys
if less_keys is None:
self.less_keys = self._default_less_keys
else:
if not isinstance(less_keys, (list, tuple)):
less_keys = (less_keys, )
assert is_seq_of(less_keys, str)
self.less_keys = less_keys
if self.save_best is not None: if self.save_best is not None:
self.best_ckpt_path = None self.best_ckpt_path = None
self._init_rule(rule, self.save_best) self._init_rule(rule, self.save_best)
...@@ -103,7 +139,8 @@ class EvalHook(Hook): ...@@ -103,7 +139,8 @@ class EvalHook(Hook):
"""Initialize rule, key_indicator, comparison_func, and best score. """Initialize rule, key_indicator, comparison_func, and best score.
Here is the rule to determine which rule is used for key indicator Here is the rule to determine which rule is used for key indicator
when the rule is not specific: when the rule is not specific (note that the key indicator matching
is case-insensitive):
1. If the key indicator is in ``self.greater_keys``, the rule will be 1. If the key indicator is in ``self.greater_keys``, the rule will be
specified as 'greater'. specified as 'greater'.
2. Or if the key indicator is in ``self.less_keys``, the rule will be 2. Or if the key indicator is in ``self.less_keys``, the rule will be
...@@ -124,13 +161,19 @@ class EvalHook(Hook): ...@@ -124,13 +161,19 @@ class EvalHook(Hook):
if rule is None: if rule is None:
if key_indicator != 'auto': if key_indicator != 'auto':
if key_indicator in self.greater_keys: # `_lc` here means we use the lower case of keys for
# case-insensitive matching
key_indicator_lc = key_indicator.lower()
greater_keys = [key.lower() for key in self.greater_keys]
less_keys = [key.lower() for key in self.less_keys]
if key_indicator_lc in greater_keys:
rule = 'greater' rule = 'greater'
elif key_indicator in self.less_keys: elif key_indicator_lc in less_keys:
rule = 'less' rule = 'less'
elif any(key in key_indicator for key in self.greater_keys): elif any(key in key_indicator_lc for key in greater_keys):
rule = 'greater' rule = 'greater'
elif any(key in key_indicator for key in self.less_keys): elif any(key in key_indicator_lc for key in less_keys):
rule = 'less' rule = 'less'
else: else:
raise ValueError(f'Cannot infer the rule for key ' raise ValueError(f'Cannot infer the rule for key '
...@@ -181,8 +224,7 @@ class EvalHook(Hook): ...@@ -181,8 +224,7 @@ class EvalHook(Hook):
if not self._should_evaluate(runner): if not self._should_evaluate(runner):
return return
from mmcv.engine import single_gpu_test results = self.test_fn(runner.model, self.dataloader)
results = single_gpu_test(runner.model, self.dataloader)
runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
key_score = self.evaluate(runner, results) key_score = self.evaluate(runner, results)
if self.save_best: if self.save_best:
...@@ -311,6 +353,10 @@ class DistEvalHook(EvalHook): ...@@ -311,6 +353,10 @@ class DistEvalHook(EvalHook):
.etc will be inferred by 'greater' rule. Keys contain 'loss' will .etc will be inferred by 'greater' rule. Keys contain 'loss' will
be inferred by 'less' rule. Options are 'greater', 'less', None. be inferred by 'less' rule. Options are 'greater', 'less', None.
Default: None. Default: None.
test_fn (callable, optional): test a model with samples from a
dataloader in a multi-gpu manner, and return the test results. If
``None``, the default test function ``mmcv.engine.multi_gpu_test``
will be used. (default: ``None``)
tmpdir (str | None): Temporary directory to save the results of all tmpdir (str | None): Temporary directory to save the results of all
processes. Default: None. processes. Default: None.
gpu_collect (bool): Whether to use gpu or cpu to collect results. gpu_collect (bool): Whether to use gpu or cpu to collect results.
...@@ -329,10 +375,18 @@ class DistEvalHook(EvalHook): ...@@ -329,10 +375,18 @@ class DistEvalHook(EvalHook):
by_epoch=True, by_epoch=True,
save_best=None, save_best=None,
rule=None, rule=None,
test_fn=None,
greater_keys=None,
less_keys=None,
broadcast_bn_buffer=True, broadcast_bn_buffer=True,
tmpdir=None, tmpdir=None,
gpu_collect=False, gpu_collect=False,
**eval_kwargs): **eval_kwargs):
if test_fn is None:
from mmcv.engine import multi_gpu_test
test_fn = multi_gpu_test
super().__init__( super().__init__(
dataloader, dataloader,
start=start, start=start,
...@@ -340,7 +394,11 @@ class DistEvalHook(EvalHook): ...@@ -340,7 +394,11 @@ class DistEvalHook(EvalHook):
by_epoch=by_epoch, by_epoch=by_epoch,
save_best=save_best, save_best=save_best,
rule=rule, rule=rule,
test_fn=test_fn,
greater_keys=greater_keys,
less_keys=less_keys,
**eval_kwargs) **eval_kwargs)
self.broadcast_bn_buffer = broadcast_bn_buffer self.broadcast_bn_buffer = broadcast_bn_buffer
self.tmpdir = tmpdir self.tmpdir = tmpdir
self.gpu_collect = gpu_collect self.gpu_collect = gpu_collect
...@@ -367,8 +425,7 @@ class DistEvalHook(EvalHook): ...@@ -367,8 +425,7 @@ class DistEvalHook(EvalHook):
if tmpdir is None: if tmpdir is None:
tmpdir = osp.join(runner.work_dir, '.eval_hook') tmpdir = osp.join(runner.work_dir, '.eval_hook')
from mmcv.engine import multi_gpu_test results = self.test_fn(
results = multi_gpu_test(
runner.model, runner.model,
self.dataloader, self.dataloader,
tmpdir=tmpdir, tmpdir=tmpdir,
......
...@@ -84,8 +84,8 @@ def _build_iter_runner(): ...@@ -84,8 +84,8 @@ def _build_iter_runner():
class EvalHook(BaseEvalHook): class EvalHook(BaseEvalHook):
greater_keys = ['acc', 'top'] _default_greater_keys = ['acc', 'top']
less_keys = ['loss', 'loss_top'] _default_less_keys = ['loss', 'loss_top']
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
...@@ -273,6 +273,31 @@ def test_eval_hook(): ...@@ -273,6 +273,31 @@ def test_eval_hook():
assert runner.meta['hook_msgs']['best_score'] == 7 assert runner.meta['hook_msgs']['best_score'] == 7
assert not osp.exists(old_ckpt_path) assert not osp.exists(old_ckpt_path)
# test EvalHook with customer test_fn and greater/less keys
loader = DataLoader(EvalDataset())
model = Model()
data_loader = DataLoader(EvalDataset())
eval_hook = EvalHook(
data_loader,
save_best='acc',
test_fn=mock.MagicMock(return_value={}),
greater_keys=[],
less_keys=['acc'])
with tempfile.TemporaryDirectory() as tmpdir:
logger = get_logger('test_eval')
runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
runner.register_checkpoint_hook(dict(interval=1))
runner.register_hook(eval_hook)
runner.run([loader], [('train', 1)], 8)
ckpt_path = osp.join(tmpdir, 'best_acc_epoch_6.pth')
assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path
assert osp.exists(ckpt_path)
assert runner.meta['hook_msgs']['best_score'] == -3
@patch('mmcv.engine.single_gpu_test', MagicMock) @patch('mmcv.engine.single_gpu_test', MagicMock)
@patch('mmcv.engine.multi_gpu_test', MagicMock) @patch('mmcv.engine.multi_gpu_test', MagicMock)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment