Unverified Commit 40518322 authored by fatih's avatar fatih Committed by GitHub
Browse files

add artifact logging to wandb hook (#1616)



* add artifact logging to wandb hook

* upload artifacts wiwth only specified suffix

* update docstring

* Update mmcv/runner/hooks/logger/wandb.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* fix linting

* add tests for wandb artifact logging

* remove redundant lines

* fix wandb tests

* init `WandbLoggerHook` with `log_artifact=True` in tests

* remove redundant lines from wandb tests

* add docstring for `with_step`
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 8abb3b29
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from mmcv.utils import scandir
from ...dist_utils import master_only from ...dist_utils import master_only
from ..hook import HOOKS from ..hook import HOOKS
from .base import LoggerHook from .base import LoggerHook
...@@ -6,6 +9,33 @@ from .base import LoggerHook ...@@ -6,6 +9,33 @@ from .base import LoggerHook
@HOOKS.register_module() @HOOKS.register_module()
class WandbLoggerHook(LoggerHook): class WandbLoggerHook(LoggerHook):
"""Class to log metrics with wandb.
It requires `wandb` to be installed.
Args:
interval (int): Logging interval (every k iterations).
Default 10.
ignore_last (bool): Ignore the log of last iterations in each epoch
if less than `interval`.
Default: True.
reset_flag (bool): Whether to clear the output buffer after logging.
Default: False.
by_epoch (bool): Whether EpochBasedRunner is used.
Default: True.
with_step (bool): If True, the step will be logged from
``self.get_iters``. Otherwise, step will not be logged.
Default: True.
log_artifact (bool): If True, artifacts in {work_dir} will be uploaded
to wandb after training ends.
Default: True
`New in version 1.4.3.`
out_suffix (str or tuple[str], optional): Those filenames ending with
``out_suffix`` will be uploaded to wandb.
Default: ('.log.json', '.log', '.py').
`New in version 1.4.3.`
"""
def __init__(self, def __init__(self,
init_kwargs=None, init_kwargs=None,
...@@ -14,13 +44,17 @@ class WandbLoggerHook(LoggerHook): ...@@ -14,13 +44,17 @@ class WandbLoggerHook(LoggerHook):
reset_flag=False, reset_flag=False,
commit=True, commit=True,
by_epoch=True, by_epoch=True,
with_step=True): with_step=True,
log_artifact=True,
out_suffix=('.log.json', '.log', '.py')):
super(WandbLoggerHook, self).__init__(interval, ignore_last, super(WandbLoggerHook, self).__init__(interval, ignore_last,
reset_flag, by_epoch) reset_flag, by_epoch)
self.import_wandb() self.import_wandb()
self.init_kwargs = init_kwargs self.init_kwargs = init_kwargs
self.commit = commit self.commit = commit
self.with_step = with_step self.with_step = with_step
self.log_artifact = log_artifact
self.out_suffix = out_suffix
def import_wandb(self): def import_wandb(self):
try: try:
...@@ -53,4 +87,11 @@ class WandbLoggerHook(LoggerHook): ...@@ -53,4 +87,11 @@ class WandbLoggerHook(LoggerHook):
@master_only @master_only
def after_run(self, runner): def after_run(self, runner):
if self.log_artifact:
wandb_artifact = self.wandb.Artifact(
name='artifacts', type='model')
for filename in scandir(runner.work_dir, self.out_suffix, True):
local_filepath = osp.join(runner.work_dir, filename)
wandb_artifact.add_file(local_filepath)
self.wandb.log_artifact(wandb_artifact)
self.wandb.join() self.wandb.join()
...@@ -1192,11 +1192,12 @@ def test_mlflow_hook(log_model): ...@@ -1192,11 +1192,12 @@ def test_mlflow_hook(log_model):
def test_wandb_hook(): def test_wandb_hook():
sys.modules['wandb'] = MagicMock() sys.modules['wandb'] = MagicMock()
runner = _build_demo_runner() runner = _build_demo_runner()
hook = WandbLoggerHook() hook = WandbLoggerHook(log_artifact=True)
loader = DataLoader(torch.ones((5, 2))) loader = DataLoader(torch.ones((5, 2)))
runner.register_hook(hook) runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)]) runner.run([loader, loader], [('train', 1), ('val', 1)])
shutil.rmtree(runner.work_dir) shutil.rmtree(runner.work_dir)
hook.wandb.init.assert_called_with() hook.wandb.init.assert_called_with()
...@@ -1206,6 +1207,7 @@ def test_wandb_hook(): ...@@ -1206,6 +1207,7 @@ def test_wandb_hook():
}, },
step=6, step=6,
commit=True) commit=True)
hook.wandb.log_artifact.assert_called()
hook.wandb.join.assert_called_with() hook.wandb.join.assert_called_with()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment