Unverified Commit d5f190d1 authored by David de la Iglesia Castro's avatar David de la Iglesia Castro Committed by GitHub
Browse files

Add MlflowLoggerHook (#221)

* Add MLflowLoggerHook

* Add MLflowLoggerHook to __all__

* Update name

* Fix tracking.MlflowClient setup

* Fix log_metric

* Fix mlflow_pytorch import

* Handle active_run

* Fix self.mlflow reference

* Simplify using high level API

* Fix set_experiment

* Add only_if_torch_available decorator and test_mlflow_hook

* Add missing import in hooks

* Fix torch available check

* Patch mlflow.pytorch in test

* Parametrize log_model

* Fix log_model parametrize

* Add docstring

* Move wandb patch

* Fix flake8

* Add regression test for non numeric metric

* Only log numbers

* Rename experiment_name-> exp_name

* Remove pytest skip
parent 728b88df
...@@ -3,8 +3,8 @@ from .checkpoint import CheckpointHook ...@@ -3,8 +3,8 @@ from .checkpoint import CheckpointHook
from .closure import ClosureHook from .closure import ClosureHook
from .hook import HOOKS, Hook from .hook import HOOKS, Hook
from .iter_timer import IterTimerHook from .iter_timer import IterTimerHook
from .logger import (LoggerHook, PaviLoggerHook, TensorboardLoggerHook, from .logger import (LoggerHook, MlflowLoggerHook, PaviLoggerHook,
TextLoggerHook, WandbLoggerHook) TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook)
from .lr_updater import LrUpdaterHook from .lr_updater import LrUpdaterHook
from .memory import EmptyCacheHook from .memory import EmptyCacheHook
from .optimizer import OptimizerHook from .optimizer import OptimizerHook
...@@ -13,6 +13,6 @@ from .sampler_seed import DistSamplerSeedHook ...@@ -13,6 +13,6 @@ from .sampler_seed import DistSamplerSeedHook
__all__ = [ __all__ = [
'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook',
'LoggerHook', 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', 'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook', 'TextLoggerHook',
'WandbLoggerHook' 'TensorboardLoggerHook', 'WandbLoggerHook'
] ]
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
from .base import LoggerHook from .base import LoggerHook
from .mlflow import MlflowLoggerHook
from .pavi import PaviLoggerHook from .pavi import PaviLoggerHook
from .tensorboard import TensorboardLoggerHook from .tensorboard import TensorboardLoggerHook
from .text import TextLoggerHook from .text import TextLoggerHook
from .wandb import WandbLoggerHook from .wandb import WandbLoggerHook
__all__ = [ __all__ = [
'LoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook', 'TextLoggerHook', 'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook',
'WandbLoggerHook' 'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook'
] ]
# Copyright (c) Open-MMLab. All rights reserved.
import numbers
from mmcv.runner import master_only
from ..hook import HOOKS
from .base import LoggerHook
@HOOKS.register_module
class MlflowLoggerHook(LoggerHook):

    def __init__(self,
                 exp_name=None,
                 tags=None,
                 log_model=True,
                 interval=10,
                 ignore_last=True,
                 reset_flag=True):
        """Class to log metrics and (optionally) a trained model to MLflow.

        It requires `MLflow`_ to be installed.

        Args:
            exp_name (str, optional): Name of the experiment to be used.
                Default None.
                If not None, set the active experiment.
                If experiment does not exist, an experiment with provided
                name will be created.
            tags (dict of str: str, optional): Tags for the current run.
                Default None.
                If not None, set tags for the current run.
            log_model (bool, optional): Whether to log an MLflow artifact.
                Default True.
                If True, log runner.model as an MLflow artifact
                for the current run.
            interval (int): Logging interval (every k iterations).
                Default 10.
            ignore_last (bool): Ignore the log of the last iterations in each
                epoch if the number of remaining iterations is less than
                `interval`. Default True.
            reset_flag (bool): Whether to clear the output buffer after
                logging. Default True.

        .. _MLflow:
            https://www.mlflow.org/docs/latest/index.html
        """
        super(MlflowLoggerHook, self).__init__(interval, ignore_last,
                                               reset_flag)
        self.import_mlflow()
        self.exp_name = exp_name
        self.tags = tags
        self.log_model = log_model

    def import_mlflow(self):
        """Import mlflow lazily so it is only required when this hook is used.

        Raises:
            ImportError: If mlflow is not installed.
        """
        try:
            import mlflow
            import mlflow.pytorch as mlflow_pytorch
        except ImportError:
            raise ImportError(
                'Please run "pip install mlflow" to install mlflow')
        self.mlflow = mlflow
        self.mlflow_pytorch = mlflow_pytorch

    @master_only
    def before_run(self, runner):
        """Set the active experiment and run tags before training starts."""
        if self.exp_name is not None:
            self.mlflow.set_experiment(self.exp_name)
        if self.tags is not None:
            self.mlflow.set_tags(self.tags)

    @master_only
    def log(self, runner):
        """Log numeric metrics from the runner's log buffer to MLflow."""
        metrics = {}
        for var, val in runner.log_buffer.output.items():
            if var in ('time', 'data_time'):
                continue
            tag = '{}/{}'.format(var, runner.mode)
            # MLflow only accepts numeric metric values; skip e.g. strings.
            if isinstance(val, numbers.Number):
                metrics[tag] = val
        # Avoid a tracking API call when there is nothing numeric to log.
        if metrics:
            self.mlflow.log_metrics(metrics, step=runner.iter)

    @master_only
    def after_run(self, runner):
        """Optionally log runner.model as an MLflow artifact after training."""
        if self.log_model:
            self.mlflow_pytorch.log_model(runner.model, 'models')
import os.path as osp import os.path as osp
import sys import sys
import warnings
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import mmcv.runner import mmcv.runner
def test_pavi_hook(): def test_pavi_hook():
try:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
except ImportError:
warnings.warn('Skipping test_pavi_hook in the absense of torch')
return
sys.modules['pavi'] = MagicMock() sys.modules['pavi'] = MagicMock()
model = nn.Linear(1, 1) model = nn.Linear(1, 1)
...@@ -40,3 +37,57 @@ def test_pavi_hook(): ...@@ -40,3 +37,57 @@ def test_pavi_hook():
tag='data', tag='data',
snapshot_file_path=osp.join(work_dir, 'latest.pth'), snapshot_file_path=osp.join(work_dir, 'latest.pth'),
iteration=5) iteration=5)
@pytest.mark.parametrize('log_model', (True, False))
def test_mlflow_hook(log_model):
    """Check MlflowLoggerHook logs metrics and (optionally) the model."""
    # Replace mlflow with mocks so no tracking server is needed.
    sys.modules['mlflow'] = MagicMock()
    sys.modules['mlflow.pytorch'] = MagicMock()

    work_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'data')
    dummy_model = nn.Linear(1, 1)
    dummy_loader = DataLoader(torch.ones((5, 5)))

    def batch_processor(model, x, **kwargs):
        return {'log_vars': {'accuracy': 0.98}, 'num_samples': 5}

    runner = mmcv.runner.Runner(
        model=dummy_model,
        work_dir=work_dir,
        batch_processor=batch_processor)
    hook = mmcv.runner.hooks.MlflowLoggerHook(
        exp_name='test', log_model=log_model)
    runner.register_hook(hook)
    runner.run([dummy_loader, dummy_loader], [('train', 1), ('val', 1)], 1)

    hook.mlflow.set_experiment.assert_called_with('test')
    hook.mlflow.log_metrics.assert_called_with({'accuracy/val': 0.98}, step=5)
    if log_model:
        hook.mlflow_pytorch.log_model.assert_called_with(
            runner.model, 'models')
    else:
        assert not hook.mlflow_pytorch.log_model.called
def test_wandb_hook():
    """Check WandbLoggerHook initializes, logs metrics and joins wandb."""
    # Replace wandb with a mock so no account/network is needed.
    sys.modules['wandb'] = MagicMock()
    hook = mmcv.runner.hooks.WandbLoggerHook()

    dummy_model = nn.Linear(1, 1)
    dummy_loader = DataLoader(torch.ones((5, 5)))

    def batch_processor(model, x, **kwargs):
        return {'log_vars': {'accuracy': 0.98}, 'num_samples': 5}

    runner = mmcv.runner.Runner(
        model=dummy_model, batch_processor=batch_processor)
    runner.register_hook(hook)
    runner.run([dummy_loader, dummy_loader], [('train', 1), ('val', 1)], 1)

    hook.wandb.init.assert_called_with()
    hook.wandb.log.assert_called_with({'accuracy/val': 0.98}, step=5)
    hook.wandb.join.assert_called_with()
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp import os.path as osp
import sys
import tempfile import tempfile
import warnings import warnings
from unittest.mock import MagicMock
sys.modules['wandb'] = MagicMock()
def test_save_checkpoint(): def test_save_checkpoint():
...@@ -32,32 +28,3 @@ def test_save_checkpoint(): ...@@ -32,32 +28,3 @@ def test_save_checkpoint():
assert osp.realpath(latest_path) == osp.realpath(epoch1_path) assert osp.realpath(latest_path) == osp.realpath(epoch1_path)
torch.load(latest_path) torch.load(latest_path)
def test_wandb_hook():
    """Check WandbLoggerHook initializes, logs metrics and joins wandb."""
    # torch is an optional dependency of this test suite; skip gracefully.
    try:
        import torch
        import torch.nn as nn
        from torch.utils.data import DataLoader
    except ImportError:
        # Fixed copy-paste bug: message previously named test_save_checkpoint
        # and misspelled "absence".
        warnings.warn('Skipping test_wandb_hook in the absence of torch')
        return
    import mmcv.runner
    # wandb is mocked at module level, so the hook only records calls.
    hook = mmcv.runner.hooks.WandbLoggerHook()
    loader = DataLoader(torch.ones((5, 5)))
    model = nn.Linear(1, 1)
    runner = mmcv.runner.Runner(
        model=model,
        batch_processor=lambda model, x, **kwargs: {
            'log_vars': {
                'accuracy': 0.98
            },
            'num_samples': 5
        })
    runner.register_hook(hook)
    runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
    hook.wandb.init.assert_called_with()
    hook.wandb.log.assert_called_with({'accuracy/val': 0.98}, step=5)
    hook.wandb.join.assert_called_with()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment