Unverified Commit 2e6c8ec8, authored by David de la Iglesia Castro and committed by GitHub

Refactor logger hooks (#605)

* Refactor tags for consistency

* Fix missing runner

* Fix missing runner

* Fix missing runner

* Fix missing runner

* Fix momentum runner hook inner iter

* Fix tests

* pre-commit run
parent 23b2bdbf
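
This commit consolidates the per-backend tag assembly (scalar filtering, learning-rate and momentum tags, train/val mode, step computation) into shared helpers on the base LoggerHook: is_scalar, get_mode, get_epoch, get_iter, get_step, get_lr_tags, get_momentum_tags and get_loggable_tags. As a rough illustration of the resulting API, here is a minimal custom hook built on those helpers; the import path and the PrintLoggerHook class are assumptions for this sketch, not part of the commit:

# Sketch only: a trivial logger hook on top of the refactored base class.
# The import path below is assumed from mmcv's layout at the time; verify locally.
from mmcv.runner.hooks.logger import LoggerHook


class PrintLoggerHook(LoggerHook):
    # Hypothetical hook: prints the merged tags instead of shipping them
    # to an external backend such as mlflow or wandb.
    def log(self, runner):
        tags = self.get_loggable_tags(runner)  # lr/momentum merged in
        if tags:
            print(f'step {self.get_step(runner)}: {tags}')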
 # Copyright (c) Open-MMLab. All rights reserved.
+import numbers
 from abc import ABCMeta, abstractmethod
 
+import numpy as np
+import torch
+
 from ..hook import Hook
@@ -31,6 +35,104 @@ class LoggerHook(Hook):
     def log(self, runner):
         pass
+
+    @staticmethod
+    def is_scalar(val, include_np=True, include_torch=True):
+        """Tell whether the input variable is a scalar.
+
+        Args:
+            val: Input variable.
+            include_np (bool): Whether to treat 0-d np.ndarray as a scalar.
+            include_torch (bool): Whether to treat 0-d torch.Tensor as a
+                scalar.
+
+        Returns:
+            bool: True or False.
+        """
+        if isinstance(val, numbers.Number):
+            return True
+        elif include_np and isinstance(val, np.ndarray) and val.ndim == 0:
+            return True
+        elif include_torch and isinstance(val, torch.Tensor) and \
+                val.dim() == 0:
+            return True
+        else:
+            return False
+
+    def get_mode(self, runner):
+        if runner.mode == 'train':
+            if 'time' in runner.log_buffer.output:
+                mode = 'train'
+            else:
+                mode = 'val'
+        elif runner.mode == 'val':
+            mode = 'val'
+        else:
+            raise ValueError(f"runner mode should be 'train' or 'val', "
+                             f'but got {runner.mode}')
+        return mode
+
+    def get_epoch(self, runner):
+        if runner.mode == 'train':
+            epoch = runner.epoch + 1
+        elif runner.mode == 'val':
+            # normal val mode
+            # runner.epoch += 1 has been done before val workflow
+            epoch = runner.epoch
+        else:
+            raise ValueError(f"runner mode should be 'train' or 'val', "
+                             f'but got {runner.mode}')
+        return epoch
+
+    def get_iter(self, runner):
+        if self.by_epoch:
+            current_iter = runner.inner_iter + 1
+        else:
+            current_iter = runner.iter + 1
+        return current_iter
+
+    def get_step(self, runner):
+        if self.get_mode(runner) == 'val' and self.by_epoch:
+            return self.get_epoch(runner)
+        else:
+            return self.get_iter(runner)
+
+    def get_lr_tags(self, runner):
+        tags = {}
+        lrs = runner.current_lr()
+        if isinstance(lrs, dict):
+            for name, value in lrs.items():
+                tags[f'learning_rate/{name}'] = value[0]
+        else:
+            tags['learning_rate'] = lrs[0]
+        return tags
+
+    def get_momentum_tags(self, runner):
+        tags = {}
+        momentums = runner.current_momentum()
+        if isinstance(momentums, dict):
+            for name, value in momentums.items():
+                tags[f'momentum/{name}'] = value[0]
+        else:
+            tags['momentum'] = momentums[0]
+        return tags
+
+    def get_loggable_tags(self,
+                          runner,
+                          allow_scalar=True,
+                          allow_text=False,
+                          tags_to_skip=('time', 'data_time')):
+        tags = {}
+        for var, val in runner.log_buffer.output.items():
+            if var in tags_to_skip:
+                continue
+            if self.is_scalar(val) and not allow_scalar:
+                continue
+            if isinstance(val, str) and not allow_text:
+                continue
+            tag = f'{var}/{self.get_mode(runner)}'
+            tags[tag] = val
+        tags.update(self.get_lr_tags(runner))
+        tags.update(self.get_momentum_tags(runner))
+        return tags
 
     def before_run(self, runner):
         for hook in runner.hooks[::-1]:
             if isinstance(hook, LoggerHook):
...
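
The scalar check above accepts plain numbers, 0-d numpy arrays and 0-d torch tensors. A quick standalone probe of that logic; the function is re-declared here so the snippet runs on its own, mirroring the method rather than importing it:

import numbers

import numpy as np
import torch


def is_scalar(val, include_np=True, include_torch=True):
    # Standalone copy of LoggerHook.is_scalar for experimentation.
    if isinstance(val, numbers.Number):
        return True
    elif include_np and isinstance(val, np.ndarray) and val.ndim == 0:
        return True
    elif include_torch and isinstance(val, torch.Tensor) and val.dim() == 0:
        return True
    return False


assert is_scalar(3.14)                      # plain Python number
assert is_scalar(np.array(1.0))             # 0-d ndarray
assert not is_scalar(np.array([1.0, 2.0]))  # 1-d ndarray is not a scalar
assert is_scalar(torch.tensor(1.0))         # 0-d tensor
assert not is_scalar('loss')                # strings are never scalars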
 # Copyright (c) Open-MMLab. All rights reserved.
-import numbers
-
 from ...dist_utils import master_only
 from ..hook import HOOKS
 from .base import LoggerHook
@@ -69,16 +67,9 @@ class MlflowLoggerHook(LoggerHook):
     @master_only
     def log(self, runner):
-        metrics = {}
-        for var, val in runner.log_buffer.output.items():
-            if var in ['time', 'data_time']:
-                continue
-            tag = f'{var}/{runner.mode}'
-            if isinstance(val, numbers.Number):
-                metrics[tag] = val
-        metrics['learning_rate'] = runner.current_lr()[0]
-        metrics['momentum'] = runner.current_momentum()[0]
-        self.mlflow.log_metrics(metrics, step=runner.iter)
+        tags = self.get_loggable_tags(runner)
+        if tags:
+            self.mlflow.log_metrics(tags, step=self.get_step(runner))
 
     @master_only
     def after_run(self, runner):
...
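
The step handed to mlflow.log_metrics now comes from get_step rather than the raw runner.iter, so logged steps become 1-based during training (iter + 1) and switch to the epoch number for by-epoch validation. A two-line illustration of the shift, which also explains the test updates further down:

# runner.iter is 0-based; the refactored hooks log one step later.
runner_iter = 0                  # first training iteration
old_step = runner_iter           # what the hooks used to pass (0)
new_step = runner_iter + 1       # what get_step returns in train mode (1)
assert (old_step, new_step) == (0, 1)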
 # Copyright (c) Open-MMLab. All rights reserved.
 import json
-import numbers
 import os
 import os.path as osp
 
-import numpy as np
-import torch
 import yaml
 
 import mmcv
@@ -14,27 +11,6 @@ from ..hook import HOOKS
 from .base import LoggerHook
-
-def is_scalar(val, include_np=True, include_torch=True):
-    """Tell the input variable is a scalar or not.
-
-    Args:
-        val: Input variable.
-        include_np (bool): Whether include 0-d np.ndarray as a scalar.
-        include_torch (bool): Whether include 0-d torch.Tensor as a scalar.
-
-    Returns:
-        bool: True or False.
-    """
-    if isinstance(val, numbers.Number):
-        return True
-    elif include_np and isinstance(val, np.ndarray) and val.ndim == 0:
-        return True
-    elif include_torch and isinstance(val, torch.Tensor) and len(val) == 1:
-        return True
-    else:
-        return False
-
 
 @HOOKS.register_module()
 class PaviLoggerHook(LoggerHook):
@@ -82,38 +58,10 @@ class PaviLoggerHook(LoggerHook):
     @master_only
     def log(self, runner):
-        tags = {}
-        for tag, val in runner.log_buffer.output.items():
-            if tag not in ['time', 'data_time'] and is_scalar(val):
-                tags[tag] = val
-        # add learning rate
-        lrs = runner.current_lr()
-        if isinstance(lrs, dict):
-            for name, value in lrs.items():
-                tags[f'learning_rate/{name}'] = value[0]
-        else:
-            tags['learning_rate'] = lrs[0]
-        # add momentum
-        momentums = runner.current_momentum()
-        if isinstance(momentums, dict):
-            for name, value in momentums.items():
-                tags[f'momentum/{name}'] = value[0]
-        else:
-            tags['momentum'] = momentums[0]
+        tags = self.get_loggable_tags(runner)
         if tags:
-            if runner.mode == 'val':
-                mode = runner.mode
-                # runner.epoch += 1 has been done before val workflow
-                epoch = runner.epoch
-            else:
-                mode = 'train' if 'time' in runner.log_buffer.output else 'val'
-                epoch = runner.epoch + 1
-            if mode == 'val' and self.by_epoch:
-                self.writer.add_scalars(mode, tags, epoch)
-            else:
-                self.writer.add_scalars(mode, tags, runner.iter)
+            self.writer.add_scalars(
+                self.get_mode(runner), tags, self.get_step(runner))
 
     @master_only
     def after_run(self, runner):
...
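
runner.current_lr() can return either a flat list (one value per parameter group) or a dict keyed by optimizer name, and get_lr_tags flattens both into tag names; get_momentum_tags is symmetric. A standalone restatement of that branch, keeping only the first param group as the hook does:

def lr_tags(lrs):
    # Mimics LoggerHook.get_lr_tags on a precomputed current_lr() value.
    tags = {}
    if isinstance(lrs, dict):
        for name, value in lrs.items():
            tags[f'learning_rate/{name}'] = value[0]  # first param group only
    else:
        tags['learning_rate'] = lrs[0]
    return tags


assert lr_tags([0.02, 0.002]) == {'learning_rate': 0.02}
assert lr_tags({'backbone': [0.01]}) == {'learning_rate/backbone': 0.01}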
@@ -43,32 +43,12 @@ class TensorboardLoggerHook(LoggerHook):
     @master_only
     def log(self, runner):
-        for var in runner.log_buffer.output:
-            if var in ['time', 'data_time']:
-                continue
-            tag = f'{var}/{runner.mode}'
-            record = runner.log_buffer.output[var]
-            if isinstance(record, str):
-                self.writer.add_text(tag, record, runner.iter)
-            else:
-                self.writer.add_scalar(tag, runner.log_buffer.output[var],
-                                       runner.iter)
-        # add learning rate
-        lrs = runner.current_lr()
-        if isinstance(lrs, dict):
-            for name, value in lrs.items():
-                self.writer.add_scalar(f'learning_rate/{name}', value[0],
-                                       runner.iter)
-        else:
-            self.writer.add_scalar('learning_rate', lrs[0], runner.iter)
-        # add momentum
-        momentums = runner.current_momentum()
-        if isinstance(momentums, dict):
-            for name, value in momentums.items():
-                self.writer.add_scalar(f'momentum/{name}', value[0],
-                                       runner.iter)
-        else:
-            self.writer.add_scalar('momentum', momentums[0], runner.iter)
+        tags = self.get_loggable_tags(runner, allow_text=True)
+        for tag, val in tags.items():
+            if isinstance(val, str):
+                self.writer.add_text(tag, val, self.get_step(runner))
+            else:
+                self.writer.add_scalar(tag, val, self.get_step(runner))
 
     @master_only
     def after_run(self, runner):
...
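
With allow_text=True, string values from the log buffer go to add_text and everything else to add_scalar, under 'var/mode' tags such as loss/train. A minimal standalone sketch of the same naming convention against torch.utils.tensorboard (mmcv may resolve the writer from tensorboardX instead; the log directory here is made up):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='./tb_demo')          # hypothetical directory
step = 1                                             # 1-based, like get_step
writer.add_scalar('loss/train', 0.42, step)          # scalar under var/mode
writer.add_scalar('learning_rate', 0.02, step)       # merged by get_lr_tags
writer.add_text('exp_name/train', 'demo run', step)  # strings -> add_text
writer.close()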
@@ -141,25 +141,11 @@ class TextLoggerHook(LoggerHook):
         return items
 
     def log(self, runner):
-        log_dict = OrderedDict()
-
-        if runner.mode == 'train':
-            log_dict['mode'] = 'train' if 'time' in runner.log_buffer.output \
-                else 'val'
-            log_dict['epoch'] = runner.epoch + 1
-        elif runner.mode == 'val':
-            # normal val mode
-            # runner.epoch += 1 has been done before val workflow
-            log_dict['mode'] = 'val'
-            log_dict['epoch'] = runner.epoch
-        else:
-            raise ValueError(f"runner mode should be 'train' or 'val', "
-                             f'but got {runner.mode}')
-        if self.by_epoch:
-            log_dict['iter'] = runner.inner_iter + 1
-        else:
-            log_dict['iter'] = runner.iter + 1
+        log_dict = OrderedDict(
+            mode=self.get_mode(runner),
+            epoch=self.get_epoch(runner),
+            iter=self.get_iter(runner))
 
         # only record lr of the first param group
         cur_lr = runner.current_lr()
         if isinstance(cur_lr, list):
...
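
The keyword form of OrderedDict preserves the mode, epoch, iter ordering (keyword-argument order is kept on Python 3.6+), so the text log still prints the fields in the same order as the removed branchy version. For instance:

from collections import OrderedDict

log_dict = OrderedDict(mode='train', epoch=1, iter=10)
assert list(log_dict) == ['mode', 'epoch', 'iter']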
 # Copyright (c) Open-MMLab. All rights reserved.
-import numbers
-
 from ...dist_utils import master_only
 from ..hook import HOOKS
 from .base import LoggerHook
@@ -39,17 +37,9 @@ class WandbLoggerHook(LoggerHook):
     @master_only
     def log(self, runner):
-        metrics = {}
-        for var, val in runner.log_buffer.output.items():
-            if var in ['time', 'data_time']:
-                continue
-            tag = f'{var}/{runner.mode}'
-            if isinstance(val, numbers.Number):
-                metrics[tag] = val
-        metrics['learning_rate'] = runner.current_lr()[0]
-        metrics['momentum'] = runner.current_momentum()[0]
-        if metrics:
-            self.wandb.log(metrics, step=runner.iter)
+        tags = self.get_loggable_tags(runner)
+        if tags:
+            self.wandb.log(tags, step=self.get_step(runner))
 
     @master_only
     def after_run(self, runner):
...
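
The wandb hook is now symmetric with the mlflow one: a single get_loggable_tags call followed by one wandb.log. A sketch of checking that call pattern with unittest.mock, mirroring the assertion style used in the tests below (the hook wiring itself is elided):

from unittest.mock import MagicMock

wandb = MagicMock()                      # stand-in for the wandb module
tags = {'learning_rate': 0.02, 'momentum': 0.95}
wandb.log(tags, step=1)                  # what the refactored log() does
wandb.log.assert_called_with(
    {'learning_rate': 0.02, 'momentum': 0.95}, step=1)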
@@ -160,15 +160,15 @@ def test_momentum_runner_hook():
         call('train', {
             'learning_rate': 0.01999999999999999,
             'momentum': 0.95
-        }, 0),
+        }, 1),
         call('train', {
             'learning_rate': 0.2,
             'momentum': 0.85
-        }, 4),
+        }, 5),
         call('train', {
             'learning_rate': 0.155,
             'momentum': 0.875
-        }, 6),
+        }, 7),
     ]
     hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
@@ -211,15 +211,15 @@ def test_cosine_runner_hook():
         call('train', {
             'learning_rate': 0.02,
             'momentum': 0.95
-        }, 0),
+        }, 1),
         call('train', {
             'learning_rate': 0.01,
             'momentum': 0.97
-        }, 5),
+        }, 6),
         call('train', {
             'learning_rate': 0.0004894348370484647,
             'momentum': 0.9890211303259032
-        }, 9)
+        }, 10)
     ]
     hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
@@ -289,15 +289,15 @@ def test_cosine_restart_lr_update_hook():
         call('train', {
             'learning_rate': 0.01,
             'momentum': 0.95
-        }, 0),
+        }, 1),
         call('train', {
             'learning_rate': 0.0,
             'momentum': 0.95
-        }, 5),
+        }, 6),
         call('train', {
             'learning_rate': 0.0009549150281252633,
             'momentum': 0.95
-        }, 9)
+        }, 10)
     ]
     hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
@@ -320,7 +320,7 @@ def test_mlflow_hook(log_model):
         {
             'learning_rate': 0.02,
             'momentum': 0.95
-        }, step=5)
+        }, step=1)
     if log_model:
         hook.mlflow_pytorch.log_model.assert_called_with(
             runner.model, 'models')
@@ -343,7 +343,7 @@ def test_wandb_hook():
             'learning_rate': 0.02,
             'momentum': 0.95
         },
-        step=5)
+        step=1)
     hook.wandb.join.assert_called_with()
...
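
For reference, the assertion pattern these tests rely on: assert_has_calls with any_order=True only requires that every expected call occurred, regardless of ordering or of extra calls in between, which is why shifting the expected step values is the entire test change. A self-contained example:

from unittest.mock import MagicMock, call

writer = MagicMock()
writer.add_scalars('train', {'learning_rate': 0.02}, 1)
writer.add_scalars('train', {'learning_rate': 0.01}, 6)
writer.add_scalars('val', {'accuracy': 0.9}, 1)      # extra call is fine

writer.add_scalars.assert_has_calls(
    [call('train', {'learning_rate': 0.01}, 6)], any_order=True)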