"tests/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "54a7ebb4ec5bbdb35110d3c44a00747c792ab1d3"
Commit 56971278 authored by Vladimir Iglovikov's avatar Vladimir Iglovikov Committed by Kai Chen
Browse files

Added Weights & Biases logger. (#143)

* Added wandb logger

* bugfix

* fixes

* fix

* flake8 fixes

* yapf fixes

* Refactor hook, add test

* flake8 fixes

* yapf fixes

* fix in test

* mock package => travis
parent aea75005
......@@ -8,7 +8,7 @@ before_install:
- sudo apt-get install -y ffmpeg
install:
- rm -rf .eggs && pip install -e . codecov flake8 yapf isort
- rm -rf .eggs && pip install -e . codecov flake8 yapf isort mock
cache:
pip: true
......
......@@ -3,7 +3,7 @@ from .log_buffer import LogBuffer
from .hooks import (Hook, CheckpointHook, ClosureHook, LrUpdaterHook,
OptimizerHook, IterTimerHook, DistSamplerSeedHook,
LoggerHook, TextLoggerHook, PaviLoggerHook,
TensorboardLoggerHook)
TensorboardLoggerHook, WandbLoggerHook)
from .checkpoint import (load_state_dict, load_checkpoint, weights_to_cpu,
save_checkpoint)
from .parallel_test import parallel_test
......@@ -15,7 +15,8 @@ __all__ = [
'Runner', 'LogBuffer', 'Hook', 'CheckpointHook', 'ClosureHook',
'LrUpdaterHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook',
'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook',
'load_state_dict', 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint',
'parallel_test', 'Priority', 'get_priority', 'get_host_info',
'get_dist_info', 'master_only', 'get_time_str', 'obj_from_dict'
'WandbLoggerHook', 'load_state_dict', 'load_checkpoint', 'weights_to_cpu',
'save_checkpoint', 'parallel_test', 'Priority', 'get_priority',
'get_host_info', 'get_dist_info', 'master_only', 'get_time_str',
'obj_from_dict'
]
......@@ -7,10 +7,11 @@ from .iter_timer import IterTimerHook
from .sampler_seed import DistSamplerSeedHook
from .memory import EmptyCacheHook
from .logger import (LoggerHook, TextLoggerHook, PaviLoggerHook,
TensorboardLoggerHook)
TensorboardLoggerHook, WandbLoggerHook)
__all__ = [
'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook',
'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook',
'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook'
'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook',
'WandbLoggerHook'
]
......@@ -2,7 +2,9 @@ from .base import LoggerHook
from .pavi import PaviLoggerHook
from .tensorboard import TensorboardLoggerHook
from .text import TextLoggerHook
from .wandb import WandbLoggerHook
__all__ = [
'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook'
'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook',
'WandbLoggerHook'
]
from ...utils import master_only
from .base import LoggerHook
import numbers
class WandbLoggerHook(LoggerHook):
    """Logger hook that pushes runner metrics to Weights & Biases (wandb).

    Args:
        log_dir (str, optional): Unused; kept for interface compatibility
            with the other logger hooks (Pavi, Tensorboard).
        interval (int): Logging interval, in iterations.
        ignore_last (bool): Whether to ignore the log of the last iterations
            of an epoch when fewer than ``interval`` remain.
        reset_flag (bool): Whether to clear the log buffer after logging.

    Raises:
        ImportError: If the ``wandb`` package is not installed.
    """

    def __init__(self,
                 log_dir=None,
                 interval=10,
                 ignore_last=True,
                 reset_flag=True):
        super(WandbLoggerHook, self).__init__(interval, ignore_last,
                                              reset_flag)
        self.import_wandb()

    def import_wandb(self):
        """Import wandb and keep the module on the hook instance.

        Storing the module on ``self`` lets tests swap in a mock object.
        """
        try:
            import wandb
        except ImportError:
            raise ImportError(
                'Please run "pip install wandb" to install wandb')
        self.wandb = wandb

    @master_only
    def before_run(self, runner):
        # Re-import if the module reference was cleared (e.g. set to None
        # externally); otherwise start a new wandb run.
        if self.wandb is None:
            self.import_wandb()
        self.wandb.init()

    @master_only
    def log(self, runner):
        """Log numeric buffer outputs as ``'<var>/<mode>'`` metrics."""
        metrics = {}
        for var, val in runner.log_buffer.output.items():
            # Timing stats are bookkeeping, not experiment metrics.
            if var in ['time', 'data_time']:
                continue
            tag = '{}/{}'.format(var, runner.mode)
            # wandb only accepts numeric scalar values here.
            if isinstance(val, numbers.Number):
                metrics[tag] = val
        if metrics:
            self.wandb.log(metrics, step=runner.iter)

    @master_only
    def after_run(self, runner):
        # Flush pending data and finish the wandb run.
        self.wandb.join()
import os.path as osp
import tempfile
import warnings
from mock import MagicMock
def test_save_checkpoint():
try:
import torch
import torch.nn as nn
from torch import nn
except ImportError:
warnings.warn('Skipping test_save_checkpoint in the absense of torch')
return
......@@ -27,3 +28,34 @@ def test_save_checkpoint():
assert osp.realpath(latest_path) == epoch1_path
torch.load(latest_path)
def test_wandb_hook():
    """Smoke-test WandbLoggerHook against a mocked wandb module.

    Runs one epoch of train + val and checks that the hook calls
    ``wandb.init``, logs the expected validation metric at the final
    iteration, and joins the run on shutdown. Skips when torch is absent.
    """
    try:
        import torch
        import torch.nn as nn
        from torch.utils.data import DataLoader
    except ImportError:
        # Fixed copy-paste: the message previously named
        # test_save_checkpoint and misspelled "absence".
        warnings.warn('Skipping test_wandb_hook in the absence of torch')
        return
    import mmcv.runner

    # Replace the real wandb module with a mock so no network/IO happens.
    wandb_mock = MagicMock()
    hook = mmcv.runner.hooks.WandbLoggerHook()
    hook.wandb = wandb_mock

    loader = DataLoader(torch.ones((5, 5)))
    model = nn.Linear(1, 1)
    runner = mmcv.runner.Runner(
        model=model,
        batch_processor=lambda model, x, **kwargs: {
            'log_vars': {
                'accuracy': 0.98
            },
            'num_samples': 5
        })
    runner.register_hook(hook)
    runner.run([loader, loader], [('train', 1), ('val', 1)], 1)

    wandb_mock.init.assert_called()
    # 5 samples per loader -> the last logged step is iteration 5, and the
    # final mode is 'val', so the tag is 'accuracy/val'.
    wandb_mock.log.assert_called_with({'accuracy/val': 0.98}, step=5)
    wandb_mock.join.assert_called()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment