Unverified commit bdd70221, authored by David de la Iglesia Castro, committed by GitHub
Browse files

Add DvcliveLoggerHook (#1075)

* Add dvclive logger hook

* Move docstring to class

* docstring updates
parent d212bd53
...@@ -10,11 +10,11 @@ from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info, ...@@ -10,11 +10,11 @@ from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
from .epoch_based_runner import EpochBasedRunner, Runner from .epoch_based_runner import EpochBasedRunner, Runner
from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model
from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook, from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook,
DistSamplerSeedHook, EMAHook, EvalHook, Fp16OptimizerHook, DistSamplerSeedHook, DvcliveLoggerHook, EMAHook, EvalHook,
Hook, IterTimerHook, LoggerHook, LrUpdaterHook, Fp16OptimizerHook, Hook, IterTimerHook, LoggerHook,
MlflowLoggerHook, NeptuneLoggerHook, OptimizerHook, LrUpdaterHook, MlflowLoggerHook, NeptuneLoggerHook,
PaviLoggerHook, SyncBuffersHook, TensorboardLoggerHook, OptimizerHook, PaviLoggerHook, SyncBuffersHook,
TextLoggerHook, WandbLoggerHook) TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook)
from .iter_based_runner import IterBasedRunner, IterLoader from .iter_based_runner import IterBasedRunner, IterLoader
from .log_buffer import LogBuffer from .log_buffer import LogBuffer
from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS, from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
...@@ -29,11 +29,11 @@ __all__ = [ ...@@ -29,11 +29,11 @@ __all__ = [
'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook',
'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
'NeptuneLoggerHook', 'WandbLoggerHook', 'MlflowLoggerHook', 'NeptuneLoggerHook', 'WandbLoggerHook', 'MlflowLoggerHook',
'_load_checkpoint', 'load_state_dict', 'load_checkpoint', 'weights_to_cpu', 'DvcliveLoggerHook', '_load_checkpoint', 'load_state_dict',
'save_checkpoint', 'Priority', 'get_priority', 'get_host_info', 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority',
'get_time_str', 'obj_from_dict', 'init_dist', 'get_dist_info', 'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict',
'master_only', 'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'init_dist', 'get_dist_info', 'master_only', 'OPTIMIZER_BUILDERS',
'DefaultOptimizerConstructor', 'build_optimizer', 'OPTIMIZERS', 'DefaultOptimizerConstructor', 'build_optimizer',
'build_optimizer_constructor', 'IterLoader', 'set_random_seed', 'build_optimizer_constructor', 'IterLoader', 'set_random_seed',
'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'Fp16OptimizerHook', 'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'Fp16OptimizerHook',
'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads', 'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads',
......
...@@ -5,9 +5,9 @@ from .ema import EMAHook ...@@ -5,9 +5,9 @@ from .ema import EMAHook
from .evaluation import DistEvalHook, EvalHook from .evaluation import DistEvalHook, EvalHook
from .hook import HOOKS, Hook from .hook import HOOKS, Hook
from .iter_timer import IterTimerHook from .iter_timer import IterTimerHook
from .logger import (LoggerHook, MlflowLoggerHook, NeptuneLoggerHook, from .logger import (DvcliveLoggerHook, LoggerHook, MlflowLoggerHook,
PaviLoggerHook, TensorboardLoggerHook, TextLoggerHook, NeptuneLoggerHook, PaviLoggerHook, TensorboardLoggerHook,
WandbLoggerHook) TextLoggerHook, WandbLoggerHook)
from .lr_updater import LrUpdaterHook from .lr_updater import LrUpdaterHook
from .memory import EmptyCacheHook from .memory import EmptyCacheHook
from .momentum_updater import MomentumUpdaterHook from .momentum_updater import MomentumUpdaterHook
...@@ -21,6 +21,7 @@ __all__ = [ ...@@ -21,6 +21,7 @@ __all__ = [
'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook', 'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook', 'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
'NeptuneLoggerHook', 'WandbLoggerHook', 'MomentumUpdaterHook', 'NeptuneLoggerHook', 'WandbLoggerHook', 'DvcliveLoggerHook',
'SyncBuffersHook', 'EMAHook', 'EvalHook', 'DistEvalHook', 'ProfilerHook' 'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', 'EvalHook',
'DistEvalHook', 'ProfilerHook'
] ]
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
from .base import LoggerHook from .base import LoggerHook
from .dvclive import DvcliveLoggerHook
from .mlflow import MlflowLoggerHook from .mlflow import MlflowLoggerHook
from .neptune import NeptuneLoggerHook from .neptune import NeptuneLoggerHook
from .pavi import PaviLoggerHook from .pavi import PaviLoggerHook
...@@ -10,5 +11,5 @@ from .wandb import WandbLoggerHook ...@@ -10,5 +11,5 @@ from .wandb import WandbLoggerHook
__all__ = [ __all__ = [
'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook', 'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook',
'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook', 'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook',
'NeptuneLoggerHook' 'NeptuneLoggerHook', 'DvcliveLoggerHook'
] ]
# Copyright (c) Open-MMLab. All rights reserved.
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
@HOOKS.register_module()
class DvcliveLoggerHook(LoggerHook):
    """Class to log metrics with dvclive.

    It requires `dvclive`_ to be installed.

    Args:
        path (str): Directory where dvclive will write TSV log files.
        interval (int): Logging interval (every k iterations).
            Default: 10.
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`.
            Default: True.
        reset_flag (bool): Whether to clear the output buffer after logging.
            Default: True.
        by_epoch (bool): Whether EpochBasedRunner is used.
            Default: True.

    .. _dvclive:
        https://dvc.org/doc/dvclive
    """

    def __init__(self,
                 path,
                 interval=10,
                 ignore_last=True,
                 reset_flag=True,
                 by_epoch=True):
        super(DvcliveLoggerHook, self).__init__(interval, ignore_last,
                                                reset_flag, by_epoch)
        self.path = path
        # Import at construction time so a missing dependency is reported
        # when the hook is built rather than at the first logging call.
        self.import_dvclive()

    def import_dvclive(self):
        """Import ``dvclive`` and store the module on ``self.dvclive``.

        Raises:
            ImportError: If ``dvclive`` cannot be imported.
        """
        try:
            import dvclive
        except ImportError:
            raise ImportError(
                'Please run "pip install dvclive" to install dvclive')
        self.dvclive = dvclive

    @master_only
    def before_run(self, runner):
        """Initialize dvclive with the configured output directory.

        Args:
            runner: The runner this hook is registered on (unused here).
        """
        # NOTE(review): this override does not call the parent
        # ``before_run`` -- confirm that is intended.
        self.dvclive.init(self.path)

    @master_only
    def log(self, runner):
        """Send every loggable tag of ``runner`` to dvclive.

        Args:
            runner: The runner whose tags and current step are logged.
        """
        tags = self.get_loggable_tags(runner)
        if tags:
            # The step is the same for every tag of this call, so it is
            # computed once instead of once per logged value.
            step = self.get_iter(runner)
            for k, v in tags.items():
                self.dvclive.log(k, v, step=step)
...@@ -18,9 +18,9 @@ import torch.nn as nn ...@@ -18,9 +18,9 @@ import torch.nn as nn
from torch.nn.init import constant_ from torch.nn.init import constant_
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from mmcv.runner import (CheckpointHook, EMAHook, IterTimerHook, from mmcv.runner import (CheckpointHook, DvcliveLoggerHook, EMAHook,
MlflowLoggerHook, NeptuneLoggerHook, PaviLoggerHook, IterTimerHook, MlflowLoggerHook, NeptuneLoggerHook,
WandbLoggerHook, build_runner) PaviLoggerHook, WandbLoggerHook, build_runner)
from mmcv.runner.hooks.hook import HOOKS, Hook from mmcv.runner.hooks.hook import HOOKS, Hook
from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook, from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook,
CyclicLrUpdaterHook, CyclicLrUpdaterHook,
...@@ -920,6 +920,7 @@ def test_neptune_hook(): ...@@ -920,6 +920,7 @@ def test_neptune_hook():
sys.modules['neptune.new'] = MagicMock() sys.modules['neptune.new'] = MagicMock()
runner = _build_demo_runner() runner = _build_demo_runner()
hook = NeptuneLoggerHook() hook = NeptuneLoggerHook()
loader = DataLoader(torch.ones((5, 2))) loader = DataLoader(torch.ones((5, 2)))
runner.register_hook(hook) runner.register_hook(hook)
...@@ -931,6 +932,23 @@ def test_neptune_hook(): ...@@ -931,6 +932,23 @@ def test_neptune_hook():
hook.run.stop.assert_called_with() hook.run.stop.assert_called_with()
def test_dvclive_hook(tmp_path):
    """Check that DvcliveLoggerHook forwards init and log calls to dvclive."""
    # Replace the real package with a mock so the hook's import succeeds
    # and every call made on the module can be inspected afterwards.
    sys.modules['dvclive'] = MagicMock()
    runner = _build_demo_runner()

    dvclive_dir = tmp_path / 'dvclive'
    dvclive_dir.mkdir()
    log_path = str(dvclive_dir)

    hook = DvcliveLoggerHook(log_path)
    runner.register_hook(hook)

    loader = DataLoader(torch.ones((5, 2)))
    workflow = [('train', 1), ('val', 1)]
    runner.run([loader, loader], workflow)
    shutil.rmtree(runner.work_dir)

    mocked_dvclive = hook.dvclive
    mocked_dvclive.init.assert_called_with(log_path)
    # 'momentum' must be the most recent log call; 'learning_rate' only
    # needs to have been logged at some point with the same step.
    mocked_dvclive.log.assert_called_with('momentum', 0.95, step=6)
    mocked_dvclive.log.assert_any_call('learning_rate', 0.02, step=6)
def _build_demo_runner_without_hook(runner_type='EpochBasedRunner', def _build_demo_runner_without_hook(runner_type='EpochBasedRunner',
max_epochs=1, max_epochs=1,
max_iters=None, max_iters=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment