Unverified Commit ac92a111 authored by David de la Iglesia Castro's avatar David de la Iglesia Castro Committed by GitHub
Browse files

`DvcliveLoggerHook` updates to work with `DVC` (#1208)



* Updates to work with DVC

* Update docstrings

* Updated test

* Updated DVCLiveLoggerHook

* Fix name

* Added missing next_step call

* Fix expected call

* Implicit next_step

* Suggestions from review

* Update test_hooks.py

* Updated to last dvclive version

* Cleaned docstring

* Update mmcv/runner/hooks/logger/dvclive.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update dvclive.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent fb486b96
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path
from ...dist_utils import master_only from ...dist_utils import master_only
from ..hook import HOOKS from ..hook import HOOKS
from .base import LoggerHook from .base import LoggerHook
...@@ -11,48 +13,62 @@ class DvcliveLoggerHook(LoggerHook): ...@@ -11,48 +13,62 @@ class DvcliveLoggerHook(LoggerHook):
It requires `dvclive`_ to be installed. It requires `dvclive`_ to be installed.
Args: Args:
path (str): Directory where dvclive will write TSV log files. model_file (str):
Default None.
If not None, after each epoch the model will
be saved to {model_file}.
interval (int): Logging interval (every k iterations). interval (int): Logging interval (every k iterations).
Default 10. Default 10.
ignore_last (bool): Ignore the log of last iterations in each epoch ignore_last (bool): Ignore the log of last iterations in each epoch
if less than `interval`. if less than `interval`.
Default: True. Default: True.
reset_flag (bool): Whether to clear the output buffer after logging. reset_flag (bool): Whether to clear the output buffer after logging.
Default: True. Default: False.
by_epoch (bool): Whether EpochBasedRunner is used. by_epoch (bool): Whether EpochBasedRunner is used.
Default: True. Default: True.
kwargs:
Arguments for instantiating `Live`_
.. _dvclive: .. _dvclive:
https://dvc.org/doc/dvclive https://dvc.org/doc/dvclive
.. _Live:
https://dvc.org/doc/dvclive/api-reference/live#parameters
""" """
def __init__(self, def __init__(self,
path, model_file=None,
interval=10, interval=10,
ignore_last=True, ignore_last=True,
reset_flag=True, reset_flag=False,
by_epoch=True): by_epoch=True,
**kwargs):
super(DvcliveLoggerHook, self).__init__(interval, ignore_last, super().__init__(interval, ignore_last, reset_flag, by_epoch)
reset_flag, by_epoch) self.model_file = model_file
self.path = path self.import_dvclive(**kwargs)
self.import_dvclive()
def import_dvclive(self): def import_dvclive(self, **kwargs):
try: try:
import dvclive from dvclive import Live
except ImportError: except ImportError:
raise ImportError( raise ImportError(
'Please run "pip install dvclive" to install dvclive') 'Please run "pip install dvclive" to install dvclive')
self.dvclive = dvclive self.dvclive = Live(**kwargs)
@master_only
def before_run(self, runner):
self.dvclive.init(self.path)
@master_only @master_only
def log(self, runner): def log(self, runner):
tags = self.get_loggable_tags(runner) tags = self.get_loggable_tags(runner)
if tags: if tags:
self.dvclive.set_step(self.get_iter(runner))
for k, v in tags.items(): for k, v in tags.items():
self.dvclive.log(k, v, step=self.get_iter(runner)) self.dvclive.log(k, v)
@master_only
def after_train_epoch(self, runner):
super().after_train_epoch(runner)
if self.model_file is not None:
runner.save_checkpoint(
Path(self.model_file).parent,
filename_tmpl=Path(self.model_file).name,
create_symlink=False,
)
...@@ -1226,21 +1226,37 @@ def test_neptune_hook(): ...@@ -1226,21 +1226,37 @@ def test_neptune_hook():
hook.run.stop.assert_called_with() hook.run.stop.assert_called_with()
def test_dvclive_hook(tmp_path): def test_dvclive_hook():
sys.modules['dvclive'] = MagicMock() sys.modules['dvclive'] = MagicMock()
runner = _build_demo_runner() runner = _build_demo_runner()
(tmp_path / 'dvclive').mkdir() hook = DvcliveLoggerHook()
hook = DvcliveLoggerHook(str(tmp_path / 'dvclive')) dvclive_mock = hook.dvclive
loader = DataLoader(torch.ones((5, 2))) loader = DataLoader(torch.ones((5, 2)))
runner.register_hook(hook) runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)]) runner.run([loader, loader], [('train', 1), ('val', 1)])
shutil.rmtree(runner.work_dir) shutil.rmtree(runner.work_dir)
hook.dvclive.init.assert_called_with(str(tmp_path / 'dvclive')) dvclive_mock.set_step.assert_called_with(6)
hook.dvclive.log.assert_called_with('momentum', 0.95, step=6) dvclive_mock.log.assert_called_with('momentum', 0.95)
hook.dvclive.log.assert_any_call('learning_rate', 0.02, step=6)
def test_dvclive_hook_model_file(tmp_path):
sys.modules['dvclive'] = MagicMock()
runner = _build_demo_runner()
hook = DvcliveLoggerHook(model_file=osp.join(runner.work_dir, 'model.pth'))
runner.register_hook(hook)
loader = torch.utils.data.DataLoader(torch.ones((5, 2)))
loader = DataLoader(torch.ones((5, 2)))
runner.run([loader, loader], [('train', 1), ('val', 1)])
assert osp.exists(osp.join(runner.work_dir, 'model.pth'))
shutil.rmtree(runner.work_dir)
def _build_demo_runner_without_hook(runner_type='EpochBasedRunner', def _build_demo_runner_without_hook(runner_type='EpochBasedRunner',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment