Unverified Commit 9709ff3f authored by takuoko's avatar takuoko Committed by GitHub
Browse files

[Enhancement] Add a new argument define_metric in wandb hook (#2237)



* wandb define_metric

* add test and some fix based on mmengine PR

* fix test

* add summary warnings

* Update mmcv/runner/hooks/logger/wandb.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/runner/hooks/logger/wandb.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent ff189047
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp import os.path as osp
import warnings
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from mmcv.utils import scandir from mmcv.utils import scandir
...@@ -43,6 +44,17 @@ class WandbLoggerHook(LoggerHook): ...@@ -43,6 +44,17 @@ class WandbLoggerHook(LoggerHook):
``out_suffix`` will be uploaded to wandb. ``out_suffix`` will be uploaded to wandb.
Default: ('.log.json', '.log', '.py'). Default: ('.log.json', '.log', '.py').
`New in version 1.4.3.` `New in version 1.4.3.`
define_metric_cfg (dict, optional): A dict of metrics and summaries for
wandb.define_metric. The key is metric and the value is summary.
The summary should be in ["min", "max", "mean" ,"best", "last",
"none"].
For example, if setting
``define_metric_cfg={'coco/bbox_mAP': 'max'}``, the maximum value
of ``coco/bbox_mAP`` will be logged on wandb UI. See
`wandb docs <https://docs.wandb.ai/ref/python/run#define_metric>`_
for details.
Defaults to None.
`New in version 1.6.3.`
.. _wandb: .. _wandb:
https://docs.wandb.ai https://docs.wandb.ai
...@@ -57,7 +69,8 @@ class WandbLoggerHook(LoggerHook): ...@@ -57,7 +69,8 @@ class WandbLoggerHook(LoggerHook):
by_epoch: bool = True, by_epoch: bool = True,
with_step: bool = True, with_step: bool = True,
log_artifact: bool = True, log_artifact: bool = True,
out_suffix: Union[str, tuple] = ('.log.json', '.log', '.py')): out_suffix: Union[str, tuple] = ('.log.json', '.log', '.py'),
define_metric_cfg: Optional[Dict] = None):
super().__init__(interval, ignore_last, reset_flag, by_epoch) super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_wandb() self.import_wandb()
self.init_kwargs = init_kwargs self.init_kwargs = init_kwargs
...@@ -65,6 +78,7 @@ class WandbLoggerHook(LoggerHook): ...@@ -65,6 +78,7 @@ class WandbLoggerHook(LoggerHook):
self.with_step = with_step self.with_step = with_step
self.log_artifact = log_artifact self.log_artifact = log_artifact
self.out_suffix = out_suffix self.out_suffix = out_suffix
self.define_metric_cfg = define_metric_cfg
def import_wandb(self) -> None: def import_wandb(self) -> None:
try: try:
...@@ -83,6 +97,15 @@ class WandbLoggerHook(LoggerHook): ...@@ -83,6 +97,15 @@ class WandbLoggerHook(LoggerHook):
self.wandb.init(**self.init_kwargs) # type: ignore self.wandb.init(**self.init_kwargs) # type: ignore
else: else:
self.wandb.init() # type: ignore self.wandb.init() # type: ignore
summary_choice = ['min', 'max', 'mean', 'best', 'last', 'none']
if self.define_metric_cfg is not None:
for metric, summary in self.define_metric_cfg.items():
if summary not in summary_choice:
warnings.warn(
f'summary should be in {summary_choice}. '
f'metric={metric}, summary={summary} will be skipped.')
self.wandb.define_metric( # type: ignore
metric, summary=summary)
@master_only @master_only
def log(self, runner) -> None: def log(self, runner) -> None:
......
...@@ -1606,7 +1606,8 @@ def test_segmind_hook(): ...@@ -1606,7 +1606,8 @@ def test_segmind_hook():
def test_wandb_hook(): def test_wandb_hook():
sys.modules['wandb'] = MagicMock() sys.modules['wandb'] = MagicMock()
runner = _build_demo_runner() runner = _build_demo_runner()
hook = WandbLoggerHook(log_artifact=True) hook = WandbLoggerHook(
log_artifact=True, define_metric_cfg={'val/loss': 'min'})
loader = DataLoader(torch.ones((5, 2))) loader = DataLoader(torch.ones((5, 2)))
runner.register_hook(hook) runner.register_hook(hook)
...@@ -1615,6 +1616,7 @@ def test_wandb_hook(): ...@@ -1615,6 +1616,7 @@ def test_wandb_hook():
shutil.rmtree(runner.work_dir) shutil.rmtree(runner.work_dir)
hook.wandb.init.assert_called_with() hook.wandb.init.assert_called_with()
hook.wandb.define_metric.assert_called_with('val/loss', summary='min')
hook.wandb.log.assert_called_with({ hook.wandb.log.assert_called_with({
'learning_rate': 0.02, 'learning_rate': 0.02,
'momentum': 0.95 'momentum': 0.95
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment