Commit 56971278 authored by Vladimir Iglovikov's avatar Vladimir Iglovikov Committed by Kai Chen
Browse files

Added Weights & Biases logger. (#143)

* Added wandb logger

* bugfix

* fixes

* fix

* flake8 fixes

* yapf fixes

* Refactor hook, add test

* flake8 fixes

* yapf fixes

* fix in test

* mock package => travis
parent aea75005
...@@ -8,7 +8,7 @@ before_install: ...@@ -8,7 +8,7 @@ before_install:
- sudo apt-get install -y ffmpeg - sudo apt-get install -y ffmpeg
install: install:
- rm -rf .eggs && pip install -e . codecov flake8 yapf isort - rm -rf .eggs && pip install -e . codecov flake8 yapf isort mock
cache: cache:
pip: true pip: true
......
...@@ -3,7 +3,7 @@ from .log_buffer import LogBuffer ...@@ -3,7 +3,7 @@ from .log_buffer import LogBuffer
from .hooks import (Hook, CheckpointHook, ClosureHook, LrUpdaterHook, from .hooks import (Hook, CheckpointHook, ClosureHook, LrUpdaterHook,
OptimizerHook, IterTimerHook, DistSamplerSeedHook, OptimizerHook, IterTimerHook, DistSamplerSeedHook,
LoggerHook, TextLoggerHook, PaviLoggerHook, LoggerHook, TextLoggerHook, PaviLoggerHook,
TensorboardLoggerHook) TensorboardLoggerHook, WandbLoggerHook)
from .checkpoint import (load_state_dict, load_checkpoint, weights_to_cpu, from .checkpoint import (load_state_dict, load_checkpoint, weights_to_cpu,
save_checkpoint) save_checkpoint)
from .parallel_test import parallel_test from .parallel_test import parallel_test
...@@ -15,7 +15,8 @@ __all__ = [ ...@@ -15,7 +15,8 @@ __all__ = [
'Runner', 'LogBuffer', 'Hook', 'CheckpointHook', 'ClosureHook', 'Runner', 'LogBuffer', 'Hook', 'CheckpointHook', 'ClosureHook',
'LrUpdaterHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LrUpdaterHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook',
'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook', 'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook',
'load_state_dict', 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'WandbLoggerHook', 'load_state_dict', 'load_checkpoint', 'weights_to_cpu',
'parallel_test', 'Priority', 'get_priority', 'get_host_info', 'save_checkpoint', 'parallel_test', 'Priority', 'get_priority',
'get_dist_info', 'master_only', 'get_time_str', 'obj_from_dict' 'get_host_info', 'get_dist_info', 'master_only', 'get_time_str',
'obj_from_dict'
] ]
...@@ -7,10 +7,11 @@ from .iter_timer import IterTimerHook ...@@ -7,10 +7,11 @@ from .iter_timer import IterTimerHook
from .sampler_seed import DistSamplerSeedHook from .sampler_seed import DistSamplerSeedHook
from .memory import EmptyCacheHook from .memory import EmptyCacheHook
from .logger import (LoggerHook, TextLoggerHook, PaviLoggerHook, from .logger import (LoggerHook, TextLoggerHook, PaviLoggerHook,
TensorboardLoggerHook) TensorboardLoggerHook, WandbLoggerHook)
__all__ = [ __all__ = [
'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook',
'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook',
'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook' 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook',
'WandbLoggerHook'
] ]
...@@ -2,7 +2,9 @@ from .base import LoggerHook ...@@ -2,7 +2,9 @@ from .base import LoggerHook
from .pavi import PaviLoggerHook from .pavi import PaviLoggerHook
from .tensorboard import TensorboardLoggerHook from .tensorboard import TensorboardLoggerHook
from .text import TextLoggerHook from .text import TextLoggerHook
from .wandb import WandbLoggerHook
__all__ = [ __all__ = [
'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook' 'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook',
'WandbLoggerHook'
] ]
from ...utils import master_only
from .base import LoggerHook
import numbers
class WandbLoggerHook(LoggerHook):
    """Logger hook that sends numeric log-buffer values to Weights & Biases.

    The ``wandb`` package is imported lazily (at construction time) so that
    merely importing this module does not require ``wandb`` to be installed.

    Args:
        log_dir: Accepted for signature compatibility; it is not used by
            this hook.
        interval: Forwarded unchanged to ``LoggerHook.__init__``.
        ignore_last: Forwarded unchanged to ``LoggerHook.__init__``.
        reset_flag: Forwarded unchanged to ``LoggerHook.__init__``.

    Raises:
        ImportError: If the ``wandb`` package cannot be imported.
    """

    def __init__(self,
                 log_dir=None,
                 interval=10,
                 ignore_last=True,
                 reset_flag=True):
        super(WandbLoggerHook, self).__init__(interval, ignore_last,
                                              reset_flag)
        self.import_wandb()

    def import_wandb(self):
        """Import ``wandb`` and store the module on ``self.wandb``.

        Raises:
            ImportError: With an installation hint when ``wandb`` is missing.
        """
        try:
            import wandb
        except ImportError:
            raise ImportError(
                'Please run "pip install wandb" to install wandb')
        self.wandb = wandb

    # NOTE(review): ``master_only`` comes from ``...utils`` and is not shown
    # here; presumably it restricts the call to the master process - confirm.
    @master_only
    def before_run(self, runner):
        """Start a wandb run, importing ``wandb`` first if it was cleared."""
        if self.wandb is None:
            self.import_wandb()
        self.wandb.init()

    @master_only
    def log(self, runner):
        """Send the numeric entries of ``runner.log_buffer.output`` to wandb.

        Entries named ``'time'`` and ``'data_time'`` are skipped, as are
        values that are not ``numbers.Number`` instances. Each remaining
        entry is logged under the key ``'<name>/<runner.mode>'`` with
        ``runner.iter`` as the step. Nothing is sent when no entry qualifies.
        """
        metrics = {}
        for var, val in runner.log_buffer.output.items():
            if var in ['time', 'data_time']:
                continue
            # Build the tag only for values that will actually be logged.
            if isinstance(val, numbers.Number):
                metrics['{}/{}'.format(var, runner.mode)] = val
        if metrics:
            self.wandb.log(metrics, step=runner.iter)

    @master_only
    def after_run(self, runner):
        """Wait for / close the wandb run once training has finished."""
        # NOTE(review): newer wandb releases appear to prefer
        # ``wandb.finish()`` over ``wandb.join()`` - confirm against the
        # supported wandb version before switching.
        self.wandb.join()
import os.path as osp import os.path as osp
import tempfile import tempfile
import warnings import warnings
from mock import MagicMock
def test_save_checkpoint(): def test_save_checkpoint():
try: try:
import torch import torch
import torch.nn as nn from torch import nn
except ImportError: except ImportError:
warnings.warn('Skipping test_save_checkpoint in the absense of torch') warnings.warn('Skipping test_save_checkpoint in the absense of torch')
return return
...@@ -27,3 +28,34 @@ def test_save_checkpoint(): ...@@ -27,3 +28,34 @@ def test_save_checkpoint():
assert osp.realpath(latest_path) == epoch1_path assert osp.realpath(latest_path) == epoch1_path
torch.load(latest_path) torch.load(latest_path)
def test_wandb_hook():
    """Check that WandbLoggerHook drives init/log/join on the wandb module.

    The real ``wandb`` module stored on the hook is replaced by a
    ``MagicMock`` so no network or account is needed; the test then runs one
    train and one val epoch and inspects the calls recorded on the mock.
    The test is skipped (with a warning) when torch is not installed.
    """
    try:
        import torch
        import torch.nn as nn
        from torch.utils.data import DataLoader
    except ImportError:
        # The message names this test (it previously named
        # test_save_checkpoint, a copy-paste slip).
        warnings.warn('Skipping test_wandb_hook in the absence of torch')
        return

    import mmcv.runner

    wandb_mock = MagicMock()
    # NOTE(review): the hook's constructor imports ``wandb`` itself, so the
    # package must be importable here even though it is mocked right after.
    hook = mmcv.runner.hooks.WandbLoggerHook()
    hook.wandb = wandb_mock

    loader = DataLoader(torch.ones((5, 5)))
    model = nn.Linear(1, 1)
    runner = mmcv.runner.Runner(
        model=model,
        batch_processor=lambda model, x, **kwargs: {
            'log_vars': {
                "accuracy": 0.98
            },
            'num_samples': 5
        })
    runner.register_hook(hook)
    runner.run([loader, loader], [('train', 1), ('val', 1)], 1)

    wandb_mock.init.assert_called()
    # assert_called_with checks only the most recent call to ``log``.
    wandb_mock.log.assert_called_with({'accuracy/val': 0.98}, step=5)
    wandb_mock.join.assert_called()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment