Unverified commit ac78bdc2, authored by Jiangmiao Pang, committed by GitHub
Browse files

Add PaviLoggerHook (#198)

* fix wandb hook

* add pavi logger hook

* modify default value of init_kwargs to None

* add more features to PaviLoggerHook

* rm taskid (will cause bug)

* directly upload saved pth

* fix CI

* register PAVI Hooks
parent 0d5332a4
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
from .checkpoint import (_load_checkpoint, load_checkpoint, load_state_dict, from .checkpoint import (_load_checkpoint, load_checkpoint, load_state_dict,
save_checkpoint, weights_to_cpu) save_checkpoint, weights_to_cpu)
from .dist_utils import get_dist_info, init_dist, master_only from .dist_utils import get_dist_info, init_dist, master_only
from .hooks import (CheckpointHook, ClosureHook, DistSamplerSeedHook, Hook, from .hooks import (CheckpointHook, ClosureHook, DistSamplerSeedHook, HOOKS,
IterTimerHook, LoggerHook, LrUpdaterHook, OptimizerHook, Hook, IterTimerHook, LoggerHook, LrUpdaterHook,
TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook) OptimizerHook, PaviLoggerHook, TensorboardLoggerHook,
TextLoggerHook, WandbLoggerHook)
from .log_buffer import LogBuffer from .log_buffer import LogBuffer
from .parallel_test import parallel_test from .parallel_test import parallel_test
from .priority import Priority, get_priority from .priority import Priority, get_priority
...@@ -12,11 +13,11 @@ from .runner import Runner ...@@ -12,11 +13,11 @@ from .runner import Runner
from .utils import get_host_info, get_time_str, obj_from_dict from .utils import get_host_info, get_time_str, obj_from_dict
__all__ = [ __all__ = [
'Runner', 'LogBuffer', 'Hook', 'CheckpointHook', 'ClosureHook', 'Runner', 'LogBuffer', 'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook',
'LrUpdaterHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LrUpdaterHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook',
'LoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', 'WandbLoggerHook', 'LoggerHook', 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
'_load_checkpoint', 'load_state_dict', 'load_checkpoint', 'weights_to_cpu', 'WandbLoggerHook', '_load_checkpoint', 'load_state_dict',
'save_checkpoint', 'parallel_test', 'Priority', 'get_priority', 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'parallel_test',
'get_host_info', 'get_time_str', 'obj_from_dict', 'init_dist', 'Priority', 'get_priority', 'get_host_info', 'get_time_str',
'get_dist_info', 'master_only' 'obj_from_dict', 'init_dist', 'get_dist_info', 'master_only'
] ]
...@@ -3,8 +3,8 @@ from .checkpoint import CheckpointHook ...@@ -3,8 +3,8 @@ from .checkpoint import CheckpointHook
from .closure import ClosureHook from .closure import ClosureHook
from .hook import HOOKS, Hook from .hook import HOOKS, Hook
from .iter_timer import IterTimerHook from .iter_timer import IterTimerHook
from .logger import (LoggerHook, TensorboardLoggerHook, TextLoggerHook, from .logger import (LoggerHook, PaviLoggerHook, TensorboardLoggerHook,
WandbLoggerHook) TextLoggerHook, WandbLoggerHook)
from .lr_updater import LrUpdaterHook from .lr_updater import LrUpdaterHook
from .memory import EmptyCacheHook from .memory import EmptyCacheHook
from .optimizer import OptimizerHook from .optimizer import OptimizerHook
...@@ -13,5 +13,6 @@ from .sampler_seed import DistSamplerSeedHook ...@@ -13,5 +13,6 @@ from .sampler_seed import DistSamplerSeedHook
__all__ = [ __all__ = [
'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook',
'LoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', 'WandbLoggerHook' 'LoggerHook', 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
'WandbLoggerHook'
] ]
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
from .base import LoggerHook from .base import LoggerHook
from .pavi import PaviLoggerHook
from .tensorboard import TensorboardLoggerHook from .tensorboard import TensorboardLoggerHook
from .text import TextLoggerHook from .text import TextLoggerHook
from .wandb import WandbLoggerHook from .wandb import WandbLoggerHook
__all__ = [ __all__ = [
'LoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', 'WandbLoggerHook' 'LoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook', 'TextLoggerHook',
'WandbLoggerHook'
] ]
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
from mmcv.runner import master_only
from ..hook import HOOKS
from .base import LoggerHook
@HOOKS.register_module
class PaviLoggerHook(LoggerHook):
    """Logger hook that forwards training scalars to a pavi ``SummaryWriter``.

    Args:
        init_kwargs (dict, optional): Keyword arguments passed to
            ``pavi.SummaryWriter``. The ``task`` and ``model`` keys are
            always overwritten in ``before_run``. Defaults to None.
        add_graph (bool): If True, call ``writer.add_graph(runner.model)``
            once the writer has been created. Defaults to False.
        add_last_ckpt (bool): If True, upload ``latest.pth`` from the work
            directory when the run finishes. Defaults to False.
        interval (int): Forwarded to ``LoggerHook``. Defaults to 10.
        ignore_last (bool): Forwarded to ``LoggerHook``. Defaults to True.
        reset_flag (bool): Forwarded to ``LoggerHook``. Defaults to True.
    """

    def __init__(self,
                 init_kwargs=None,
                 add_graph=False,
                 add_last_ckpt=False,
                 interval=10,
                 ignore_last=True,
                 reset_flag=True):
        super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag)
        self.init_kwargs = init_kwargs
        self.add_graph = add_graph
        self.add_last_ckpt = add_last_ckpt

    @master_only
    def before_run(self, runner):
        """Create the pavi writer for this run.

        Raises:
            ImportError: If the ``pavi`` package is not installed.
        """
        # Imported lazily so that pavi is only required when this hook
        # is actually used.
        try:
            from pavi import SummaryWriter
        except ImportError:
            raise ImportError('Please run "pip install pavi" to install pavi.')

        # Use the last path component of the work directory as the run name.
        # normpath drops a trailing separator, which would otherwise produce
        # an empty name, and basename respects the platform's separators.
        self.run_name = osp.basename(osp.normpath(runner.work_dir))

        # Work on a copy so the dict supplied by the caller (typically part
        # of a shared config) is not modified in place.
        init_kwargs = dict(self.init_kwargs) if self.init_kwargs else dict()
        init_kwargs['task'] = self.run_name
        init_kwargs['model'] = runner._model_name
        self.init_kwargs = init_kwargs

        self.writer = SummaryWriter(**self.init_kwargs)

        if self.add_graph:
            self.writer.add_graph(runner.model)

    @master_only
    def after_run(self, runner):
        """Upload ``latest.pth`` from the work directory if requested."""
        if self.add_last_ckpt:
            ckpt_path = osp.join(runner.work_dir, 'latest.pth')
            self.writer.add_snapshot_file(
                tag=self.run_name,
                snapshot_file_path=ckpt_path,
                iteration=runner.iter)

    @master_only
    def log(self, runner):
        """Send the current log buffer output to the writer.

        The ``time`` and ``data_time`` entries are skipped; everything else
        is written in one ``add_scalars`` call, grouped under
        ``runner.mode`` at step ``runner.iter``.
        """
        tags = {
            tag: val
            for tag, val in runner.log_buffer.output.items()
            if tag not in ('time', 'data_time')
        }
        # Nothing to report if only timing entries were present.
        if tags:
            self.writer.add_scalars(runner.mode, tags, runner.iter)
...@@ -10,13 +10,14 @@ from .base import LoggerHook ...@@ -10,13 +10,14 @@ from .base import LoggerHook
class WandbLoggerHook(LoggerHook): class WandbLoggerHook(LoggerHook):
def __init__(self, def __init__(self,
log_dir=None, init_kwargs=None,
interval=10, interval=10,
ignore_last=True, ignore_last=True,
reset_flag=True): reset_flag=True):
super(WandbLoggerHook, self).__init__(interval, ignore_last, super(WandbLoggerHook, self).__init__(interval, ignore_last,
reset_flag) reset_flag)
self.import_wandb() self.import_wandb()
self.init_kwargs = init_kwargs
def import_wandb(self): def import_wandb(self):
try: try:
...@@ -30,7 +31,10 @@ class WandbLoggerHook(LoggerHook): ...@@ -30,7 +31,10 @@ class WandbLoggerHook(LoggerHook):
def before_run(self, runner): def before_run(self, runner):
if self.wandb is None: if self.wandb is None:
self.import_wandb() self.import_wandb()
self.wandb.init() if self.init_kwargs:
self.wandb.init(**self.init_kwargs)
else:
self.wandb.init()
@master_only @master_only
def log(self, runner): def log(self, runner):
...@@ -39,7 +43,6 @@ class WandbLoggerHook(LoggerHook): ...@@ -39,7 +43,6 @@ class WandbLoggerHook(LoggerHook):
if var in ['time', 'data_time']: if var in ['time', 'data_time']:
continue continue
tag = '{}/{}'.format(var, runner.mode) tag = '{}/{}'.format(var, runner.mode)
runner.log_buffer.output[var]
if isinstance(val, numbers.Number): if isinstance(val, numbers.Number):
metrics[tag] = val metrics[tag] = val
if metrics: if metrics:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment