Unverified Commit 34f227ef authored by saurbhc's avatar saurbhc Committed by GitHub
Browse files

[Feature] Add SegmindLoggerHook (#1650)



* add SegmindLoggerHook

* update linting for segmind.py

* ran pre-commit

* add test_segmind_hook

- add SegmindLoggerHook import in:
  mmcv/runner/__init__.py
  mmcv/runner/hooks/__init__.py
  mmcv/runner/hooks/logger/__init__.py

* update test_segmind_hook

- Add Docstring to SegmindLoggerHook
- Use get_loggable_tags(...)

* update test_hooks.py & segmind.py

- mmcv/runner/hooks/logger/segmind.py
  moved docs from __init__ to class ...
  update ImportError line-indentation
  remove unwanted method
- tests/test_runner/test_hooks.py
  update assert_called_with only on hook.segmind_mlflow_log

* Update tests/test_runner/test_hooks.py

disable yapf on test_hooks.py imports
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/runner/hooks/logger/segmind.py

Update SegmindLoggerHook docstring
Co-authored-by: default avatarJiazhen Wang <47851024+teamwong111@users.noreply.github.com>

* update before_run method in segmind.py

removed unused statements

* updated imports in SegmindLoggerHook

* update SegmindLoggerHook

- code cleanup

* update SegmindLoggerHook

- add interval parameter in __init__ method

* update SegmindLoggerHook

- more arguments passed to the __init__ method
  - interval
  - ignore_last
  - reset_flag
  - by_epoch

* Update mmcv/runner/hooks/logger/segmind.py
Co-authored-by: default avatarMashiro <57566630+HAOCHENYE@users.noreply.github.com>

* Update mmcv/runner/hooks/logger/segmind.py
Co-authored-by: default avatarMashiro <57566630+HAOCHENYE@users.noreply.github.com>
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: default avatarJiazhen Wang <47851024+teamwong111@users.noreply.github.com>
Co-authored-by: default avatarMashiro <57566630+HAOCHENYE@users.noreply.github.com>
parent 4e773820
......@@ -15,8 +15,9 @@ from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook,
Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
GradientCumulativeOptimizerHook, Hook, IterTimerHook,
LoggerHook, MlflowLoggerHook, NeptuneLoggerHook,
OptimizerHook, PaviLoggerHook, SyncBuffersHook,
TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook)
OptimizerHook, PaviLoggerHook, SegmindLoggerHook,
SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook,
WandbLoggerHook)
from .hooks.lr_updater import StepLrUpdaterHook # noqa
from .hooks.lr_updater import (CosineAnnealingLrUpdaterHook,
CosineRestartLrUpdaterHook, CyclicLrUpdaterHook,
......@@ -60,5 +61,6 @@ __all__ = [
'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule',
'_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential',
'ModuleDict', 'ModuleList', 'GradientCumulativeOptimizerHook',
'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor'
'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor',
'SegmindLoggerHook'
]
......@@ -6,8 +6,8 @@ from .evaluation import DistEvalHook, EvalHook
from .hook import HOOKS, Hook
from .iter_timer import IterTimerHook
from .logger import (DvcliveLoggerHook, LoggerHook, MlflowLoggerHook,
NeptuneLoggerHook, PaviLoggerHook, TensorboardLoggerHook,
TextLoggerHook, WandbLoggerHook)
NeptuneLoggerHook, PaviLoggerHook, SegmindLoggerHook,
TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook)
from .lr_updater import (CosineAnnealingLrUpdaterHook,
CosineRestartLrUpdaterHook, CyclicLrUpdaterHook,
ExpLrUpdaterHook, FixedLrUpdaterHook,
......@@ -38,5 +38,6 @@ __all__ = [
'StepMomentumUpdaterHook', 'CosineAnnealingMomentumUpdaterHook',
'CyclicMomentumUpdaterHook', 'OneCycleMomentumUpdaterHook',
'SyncBuffersHook', 'EMAHook', 'EvalHook', 'DistEvalHook', 'ProfilerHook',
'GradientCumulativeOptimizerHook', 'GradientCumulativeFp16OptimizerHook'
'GradientCumulativeOptimizerHook', 'GradientCumulativeFp16OptimizerHook',
'SegmindLoggerHook'
]
......@@ -4,6 +4,7 @@ from .dvclive import DvcliveLoggerHook
from .mlflow import MlflowLoggerHook
from .neptune import NeptuneLoggerHook
from .pavi import PaviLoggerHook
from .segmind import SegmindLoggerHook
from .tensorboard import TensorboardLoggerHook
from .text import TextLoggerHook
from .wandb import WandbLoggerHook
......@@ -11,5 +12,5 @@ from .wandb import WandbLoggerHook
__all__ = [
'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook',
'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook',
'NeptuneLoggerHook', 'DvcliveLoggerHook'
'NeptuneLoggerHook', 'DvcliveLoggerHook', 'SegmindLoggerHook'
]
# Copyright (c) OpenMMLab. All rights reserved.
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
@HOOKS.register_module()
class SegmindLoggerHook(LoggerHook):
    """Class to log metrics to Segmind.

    It requires `Segmind`_ to be installed.

    Args:
        interval (int): Logging interval (every k iterations). Default: 10.
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`. Default True.
        reset_flag (bool): Whether to clear the output buffer after logging.
            Default False.
        by_epoch (bool): Whether EpochBasedRunner is used. Default True.

    .. _Segmind:
        https://docs.segmind.com/python-library
    """

    def __init__(self,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 by_epoch=True):
        super().__init__(interval, ignore_last, reset_flag, by_epoch)
        self.import_segmind()

    def import_segmind(self):
        """Import segmind and bind its logging helpers onto the hook.

        Raises:
            ImportError: If the optional ``segmind`` package is not installed.
        """
        try:
            import segmind
        except ImportError:
            raise ImportError(
                "Please run 'pip install segmind' to install segmind")
        # Keep direct references so `log` can call them without re-importing.
        self.log_metrics = segmind.tracking.fluent.log_metrics
        self.mlflow_log = segmind.utils.logging_utils.try_mlflow_log

    @master_only
    def log(self, runner):
        """Log the runner's current loggable tags to Segmind.

        Runs only on the master process. Does nothing when there is
        nothing to log at this step.
        """
        tags = self.get_loggable_tags(runner)
        if not tags:
            return
        # logging metrics to segmind
        self.mlflow_log(
            self.log_metrics, tags, step=runner.epoch, epoch=runner.epoch)
......@@ -22,12 +22,15 @@ from torch.nn.init import constant_
from torch.utils.data import DataLoader
from mmcv.fileio.file_client import PetrelBackend
# yapf: disable
from mmcv.runner import (CheckpointHook, DvcliveLoggerHook, EMAHook,
Fp16OptimizerHook,
GradientCumulativeFp16OptimizerHook,
GradientCumulativeOptimizerHook, IterTimerHook,
MlflowLoggerHook, NeptuneLoggerHook, OptimizerHook,
PaviLoggerHook, WandbLoggerHook, build_runner)
PaviLoggerHook, SegmindLoggerHook, WandbLoggerHook,
build_runner)
# yapf: enable
from mmcv.runner.fp16_utils import auto_fp16
from mmcv.runner.hooks.hook import HOOKS, Hook
from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook,
......@@ -1401,6 +1404,25 @@ def test_mlflow_hook(log_model):
assert not hook.mlflow_pytorch.log_model.called
def test_segmind_hook():
    # Stub out the optional segmind dependency so the hook's import succeeds
    # and its bound logging callables become MagicMocks we can inspect.
    sys.modules['segmind'] = MagicMock()
    runner = _build_demo_runner()
    hook = SegmindLoggerHook()
    runner.register_hook(hook)

    data_loader = DataLoader(torch.ones((5, 2)))
    runner.run([data_loader, data_loader], [('train', 1), ('val', 1)])
    shutil.rmtree(runner.work_dir)

    expected_metrics = {'learning_rate': 0.02, 'momentum': 0.95}
    hook.mlflow_log.assert_called_with(
        hook.log_metrics,
        expected_metrics,
        step=runner.epoch,
        epoch=runner.epoch)
def test_wandb_hook():
sys.modules['wandb'] = MagicMock()
runner = _build_demo_runner()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment