Unverified Commit 2e6c8ec8 authored by David de la Iglesia Castro, committed by GitHub

Refactor logger hooks (#605)

* Refactor tags for consistency

* Fix missing runner

* Fix momentum runner hook inner iter

* Fix tests

* pre-commit run
parent 23b2bdbf
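Taken together, these commits move tag collection, filtering, and step bookkeeping from the individual backend hooks into the LoggerHook base class, so every backend's log() shrinks to the same few lines. A minimal sketch of the resulting subclass contract, assuming the usual mmcv.runner exports; ConsoleLoggerHook and its print backend are hypothetical stand-ins for a real logging backend:

from mmcv.runner import HOOKS
from mmcv.runner.hooks.logger import LoggerHook


@HOOKS.register_module()
class ConsoleLoggerHook(LoggerHook):
    """Hypothetical hook showing the refactored backend contract."""

    def log(self, runner):
        # The shared helper skips time/data_time, filters out
        # non-loggable values, and appends learning_rate/momentum tags.
        tags = self.get_loggable_tags(runner)
        if tags:
            # get_step() resolves to the epoch for by-epoch val
            # logging and to the 1-based iteration otherwise.
            print(self.get_step(runner), tags)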
# Copyright (c) Open-MMLab. All rights reserved.
import numbers
from abc import ABCMeta, abstractmethod
import numpy as np
import torch
from ..hook import Hook
@@ -31,6 +35,104 @@ class LoggerHook(Hook):
def log(self, runner):
pass
@staticmethod
def is_scalar(val, include_np=True, include_torch=True):
"""Tell the input variable is a scalar or not.
Args:
val: Input variable.
include_np (bool): Whether include 0-d np.ndarray as a scalar.
include_torch (bool): Whether include 0-d torch.Tensor as a scalar.
Returns:
bool: True or False.
"""
if isinstance(val, numbers.Number):
return True
elif include_np and isinstance(val, np.ndarray) and val.ndim == 0:
return True
# note: len() raises a TypeError on a 0-d tensor, so check ndim instead
elif include_torch and isinstance(val, torch.Tensor) and val.ndim == 0:
return True
else:
return False
def get_mode(self, runner):
if runner.mode == 'train':
if 'time' in runner.log_buffer.output:
mode = 'train'
else:
mode = 'val'
elif runner.mode == 'val':
mode = 'val'
else:
raise ValueError(f"runner mode should be 'train' or 'val', "
f'but got {runner.mode}')
return mode
def get_epoch(self, runner):
if runner.mode == 'train':
epoch = runner.epoch + 1
elif runner.mode == 'val':
# normal val mode
# runner.epoch += 1 has been done before val workflow
epoch = runner.epoch
else:
raise ValueError(f"runner mode should be 'train' or 'val', "
f'but got {runner.mode}')
return epoch
def get_iter(self, runner):
if self.by_epoch:
current_iter = runner.inner_iter + 1
else:
current_iter = runner.iter + 1
return current_iter
def get_step(self, runner):
if self.get_mode(runner) == 'val' and self.by_epoch:
return self.get_epoch(runner)
else:
return self.get_iter(runner)
def get_lr_tags(self, runner):
tags = {}
lrs = runner.current_lr()
if isinstance(lrs, dict):
for name, value in lrs.items():
tags[f'learning_rate/{name}'] = value[0]
else:
tags['learning_rate'] = lrs[0]
return tags
def get_momentum_tags(self, runner):
tags = {}
momentums = runner.current_momentum()
if isinstance(momentums, dict):
for name, value in momentums.items():
tags[f'momentum/{name}'] = value[0]
else:
tags['momentum'] = momentums[0]
return tags
def get_loggable_tags(self,
runner,
allow_scalar=True,
allow_text=False,
tags_to_skip=('time', 'data_time')):
tags = {}
for var, val in runner.log_buffer.output.items():
if var in tags_to_skip:
continue
if self.is_scalar(val) and not allow_scalar:
continue
if isinstance(val, str) and not allow_text:
continue
tag = f'{var}/{self.get_mode(runner)}'
tags[tag] = val
tags.update(self.get_lr_tags(runner))
tags.update(self.get_momentum_tags(runner))
return tags
def before_run(self, runner):
for hook in runner.hooks[::-1]:
if isinstance(hook, LoggerHook):
......
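As a quick reference for the helper that moved into the base class, a hedged sanity check of is_scalar's behavior (this assumes the 0-d tensor check as written above): plain numbers and 0-d arrays/tensors pass, anything with a real dimension or a string does not.

import numpy as np
import torch

from mmcv.runner.hooks.logger import LoggerHook

assert LoggerHook.is_scalar(0.5)                   # plain Python number
assert LoggerHook.is_scalar(np.array(0.5))         # 0-d ndarray
assert LoggerHook.is_scalar(torch.tensor(0.5))     # 0-d tensor
assert not LoggerHook.is_scalar(np.array([0.5]))   # 1-d, even with one item
assert not LoggerHook.is_scalar('0.5')             # strings are never scalars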
# Copyright (c) Open-MMLab. All rights reserved.
import numbers
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
@@ -69,16 +67,9 @@ class MlflowLoggerHook(LoggerHook):
@master_only
def log(self, runner):
metrics = {}
for var, val in runner.log_buffer.output.items():
if var in ['time', 'data_time']:
continue
tag = f'{var}/{runner.mode}'
if isinstance(val, numbers.Number):
metrics[tag] = val
metrics['learning_rate'] = runner.current_lr()[0]
metrics['momentum'] = runner.current_momentum()[0]
self.mlflow.log_metrics(metrics, step=runner.iter)
tags = self.get_loggable_tags(runner)
if tags:
self.mlflow.log_metrics(tags, step=self.get_step(runner))
@master_only
def after_run(self, runner):
......
# Copyright (c) Open-MMLab. All rights reserved.
import json
import numbers
import os
import os.path as osp
import numpy as np
import torch
import yaml
import mmcv
@@ -14,27 +11,6 @@ from ..hook import HOOKS
from .base import LoggerHook
def is_scalar(val, include_np=True, include_torch=True):
"""Tell the input variable is a scalar or not.
Args:
val: Input variable.
include_np (bool): Whether include 0-d np.ndarray as a scalar.
include_torch (bool): Whether include 0-d torch.Tensor as a scalar.
Returns:
bool: True or False.
"""
if isinstance(val, numbers.Number):
return True
elif include_np and isinstance(val, np.ndarray) and val.ndim == 0:
return True
elif include_torch and isinstance(val, torch.Tensor) and len(val) == 1:
return True
else:
return False
@HOOKS.register_module()
class PaviLoggerHook(LoggerHook):
@@ -82,38 +58,10 @@ class PaviLoggerHook(LoggerHook):
@master_only
def log(self, runner):
tags = {}
for tag, val in runner.log_buffer.output.items():
if tag not in ['time', 'data_time'] and is_scalar(val):
tags[tag] = val
# add learning rate
lrs = runner.current_lr()
if isinstance(lrs, dict):
for name, value in lrs.items():
tags[f'learning_rate/{name}'] = value[0]
else:
tags['learning_rate'] = lrs[0]
# add momentum
momentums = runner.current_momentum()
if isinstance(momentums, dict):
for name, value in momentums.items():
tags[f'momentum/{name}'] = value[0]
else:
tags['momentum'] = momentums[0]
tags = self.get_loggable_tags(runner)
if tags:
if runner.mode == 'val':
mode = runner.mode
# runner.epoch += 1 has been done before val workflow
epoch = runner.epoch
else:
mode = 'train' if 'time' in runner.log_buffer.output else 'val'
epoch = runner.epoch + 1
if mode == 'val' and self.by_epoch:
self.writer.add_scalars(mode, tags, epoch)
else:
self.writer.add_scalars(mode, tags, runner.iter)
self.writer.add_scalars(
self.get_mode(runner), tags, self.get_step(runner))
@master_only
def after_run(self, runner):
......
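The mode/epoch branching that the old Pavi code carried inline is now one call each to get_mode and get_step. A hedged restatement of the step rule as standalone code, purely for illustration (resolve_step is not part of the API):

def resolve_step(mode, by_epoch, epoch, iteration):
    # Mirrors LoggerHook.get_step(): by-epoch validation is plotted
    # against the epoch axis, everything else against the 1-based
    # iteration axis.
    if mode == 'val' and by_epoch:
        return epoch
    return iteration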
@@ -43,32 +43,12 @@ class TensorboardLoggerHook(LoggerHook):
@master_only
def log(self, runner):
for var in runner.log_buffer.output:
if var in ['time', 'data_time']:
continue
tag = f'{var}/{runner.mode}'
record = runner.log_buffer.output[var]
if isinstance(record, str):
self.writer.add_text(tag, record, runner.iter)
tags = self.get_loggable_tags(runner, allow_text=True)
for tag, val in tags.items():
if isinstance(val, str):
self.writer.add_text(tag, val, self.get_step(runner))
else:
self.writer.add_scalar(tag, runner.log_buffer.output[var],
runner.iter)
# add learning rate
lrs = runner.current_lr()
if isinstance(lrs, dict):
for name, value in lrs.items():
self.writer.add_scalar(f'learning_rate/{name}', value[0],
runner.iter)
else:
self.writer.add_scalar('learning_rate', lrs[0], runner.iter)
# add momentum
momentums = runner.current_momentum()
if isinstance(momentums, dict):
for name, value in momentums.items():
self.writer.add_scalar(f'momentum/{name}', value[0],
runner.iter)
else:
self.writer.add_scalar('momentum', momentums[0], runner.iter)
self.writer.add_scalar(tag, val, self.get_step(runner))
@master_only
def after_run(self, runner):
......
@@ -141,25 +141,11 @@ class TextLoggerHook(LoggerHook):
return items
def log(self, runner):
log_dict = OrderedDict()
if runner.mode == 'train':
log_dict['mode'] = 'train' if 'time' in runner.log_buffer.output \
else 'val'
log_dict['epoch'] = runner.epoch + 1
elif runner.mode == 'val':
# normal val mode
# runner.epoch += 1 has been done before val workflow
log_dict['mode'] = 'val'
log_dict['epoch'] = runner.epoch
else:
raise ValueError(f"runner mode should be 'train' or 'val', "
f'but got {runner.mode}')
log_dict = OrderedDict(
mode=self.get_mode(runner),
epoch=self.get_epoch(runner),
iter=self.get_iter(runner))
if self.by_epoch:
log_dict['iter'] = runner.inner_iter + 1
else:
log_dict['iter'] = runner.iter + 1
# only record lr of the first param group
cur_lr = runner.current_lr()
if isinstance(cur_lr, list):
......
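With the helpers in place, the first three fields of log_dict come straight from the base class. As a hypothetical worked example, a by-epoch training run at the 51st batch of the first epoch (epoch 0, inner_iter 50, both 0-based) would start from:

from collections import OrderedDict

# Illustrative values only: get_mode/get_epoch/get_iter convert the
# runner's 0-based counters into the 1-based fields printed below.
log_dict = OrderedDict(mode='train', epoch=1, iter=51)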
# Copyright (c) Open-MMLab. All rights reserved.
import numbers
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
@@ -39,17 +37,9 @@ class WandbLoggerHook(LoggerHook):
@master_only
def log(self, runner):
metrics = {}
for var, val in runner.log_buffer.output.items():
if var in ['time', 'data_time']:
continue
tag = f'{var}/{runner.mode}'
if isinstance(val, numbers.Number):
metrics[tag] = val
metrics['learning_rate'] = runner.current_lr()[0]
metrics['momentum'] = runner.current_momentum()[0]
if metrics:
self.wandb.log(metrics, step=runner.iter)
tags = self.get_loggable_tags(runner)
if tags:
self.wandb.log(tags, step=self.get_step(runner))
@master_only
def after_run(self, runner):
......
@@ -160,15 +160,15 @@ def test_momentum_runner_hook():
call('train', {
'learning_rate': 0.01999999999999999,
'momentum': 0.95
}, 0),
}, 1),
call('train', {
'learning_rate': 0.2,
'momentum': 0.85
}, 4),
}, 5),
call('train', {
'learning_rate': 0.155,
'momentum': 0.875
}, 6),
}, 7),
]
hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
@@ -211,15 +211,15 @@ def test_cosine_runner_hook():
call('train', {
'learning_rate': 0.02,
'momentum': 0.95
}, 0),
}, 1),
call('train', {
'learning_rate': 0.01,
'momentum': 0.97
}, 5),
}, 6),
call('train', {
'learning_rate': 0.0004894348370484647,
'momentum': 0.9890211303259032
}, 9)
}, 10)
]
hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
@@ -289,15 +289,15 @@ def test_cosine_restart_lr_update_hook():
call('train', {
'learning_rate': 0.01,
'momentum': 0.95
}, 0),
}, 1),
call('train', {
'learning_rate': 0.0,
'momentum': 0.95
}, 5),
}, 6),
call('train', {
'learning_rate': 0.0009549150281252633,
'momentum': 0.95
}, 9)
}, 10)
]
hook.writer.add_scalars.assert_has_calls(calls, any_order=True)
@@ -320,7 +320,7 @@ def test_mlflow_hook(log_model):
{
'learning_rate': 0.02,
'momentum': 0.95
}, step=5)
}, step=1)
if log_model:
hook.mlflow_pytorch.log_model.assert_called_with(
runner.model, 'models')
@@ -343,7 +343,7 @@ def test_wandb_hook():
'learning_rate': 0.02,
'momentum': 0.95
},
step=5)
step=1)
hook.wandb.join.assert_called_with()
......
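Note that every expected step in the test updates above shifts by one (0 -> 1, 4 -> 5, and so on): the hooks now log through get_step, which resolves to a 1-based iteration count (runner.inner_iter + 1 or runner.iter + 1 depending on by_epoch), so the first logged step is 1 rather than 0. A minimal illustration with a hypothetical stand-in runner:

class FakeRunner:
    # Only the counters LoggerHook.get_iter() reads; both are still 0
    # while the first iteration is being logged.
    iter = 0
    inner_iter = 0

# With by_epoch=False, get_iter() yields runner.iter + 1 == 1,
# matching the new expected steps in the calls above.
assert FakeRunner.iter + 1 == 1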