Unverified Commit fd066152 authored by Jiangmiao Pang's avatar Jiangmiao Pang Committed by GitHub
Browse files

Update PaviLoggerHook (#202)

* add is_scalar in PaviLoggerHook and add unittest

* add docstrings
parent 120c6a64
# Copyright (c) Open-MMLab. All rights reserved.
import numbers
import os.path as osp
import numpy as np
import torch
from mmcv.runner import master_only
from ..hook import HOOKS
from .base import LoggerHook
def is_scalar(val, include_np=True, include_torch=True):
"""Tell the input variable is a scalar or not.
Args:
val: Input variable.
include_np (bool): Whether include 0-d np.ndarray as a scalar.
include_torch (bool): Whether include 0-d torch.Tensor as a scalar.
Returns:
bool: True or False.
"""
if isinstance(val, numbers.Number):
return True
elif include_np and isinstance(val, np.ndarray) and val.ndim == 0:
return True
elif include_torch and isinstance(val, torch.Tensor) and len(val) == 1:
return True
else:
return False
@HOOKS.register_module
class PaviLoggerHook(LoggerHook):
......@@ -40,6 +65,15 @@ class PaviLoggerHook(LoggerHook):
if self.add_graph:
self.writer.add_graph(runner.model)
@master_only
def log(self, runner):
tags = {}
for tag, val in runner.log_buffer.output.items():
if tag not in ['time', 'data_time'] and is_scalar(val):
tags[tag] = val
if tags:
self.writer.add_scalars(runner.mode, tags, runner.iter)
@master_only
def after_run(self, runner):
if self.add_last_ckpt:
......@@ -48,13 +82,3 @@ class PaviLoggerHook(LoggerHook):
tag=self.run_name,
snapshot_file_path=ckpt_path,
iteration=runner.iter)
@master_only
def log(self, runner):
tags = {}
for tag, val in runner.log_buffer.output.items():
if tag in ['time', 'data_time']:
continue
tags[tag] = val
if tags:
self.writer.add_scalars(runner.mode, tags, runner.iter)
import os.path as osp
import sys
import warnings
from mock import MagicMock
import mmcv.runner
def test_pavi_hook():
try:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
except ImportError:
warnings.warn('Skipping test_pavi_hook in the absense of torch')
return
sys.modules['pavi'] = MagicMock()
model = nn.Linear(1, 1)
loader = DataLoader(torch.ones((5, 5)))
work_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'data')
runner = mmcv.runner.Runner(
model=model,
work_dir=work_dir,
batch_processor=lambda model, x, **kwargs: {
'log_vars': {
'loss': 2.333
},
'num_samples': 5
})
hook = mmcv.runner.hooks.PaviLoggerHook(
add_graph=False, add_last_ckpt=True)
runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
assert hasattr(hook, 'writer')
hook.writer.add_scalars.assert_called_with('val', {'loss': 2.333}, 5)
hook.writer.add_snapshot_file.assert_called_with(
tag='data',
snapshot_file_path=osp.join(work_dir, 'latest.pth'),
iteration=5)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment