"docs/en/get_started/installation.md" did not exist on "cdcbc03c98d2a60a8c599182f125a2cd21e32bec"
Unverified Commit ac78bdc2 authored by Jiangmiao Pang's avatar Jiangmiao Pang Committed by GitHub
Browse files

Add PaviLoggerHook (#198)

* fix wandb hook

* add pavi logger hook

* modify default value of init_kwargs to None

* add more features to PaviLoggerHook

* rm taskid (will cause bug)

* directly upload saved pth

* fix CI

* register PAVI Hooks
parent 0d5332a4
......@@ -2,9 +2,10 @@
from .checkpoint import (_load_checkpoint, load_checkpoint, load_state_dict,
save_checkpoint, weights_to_cpu)
from .dist_utils import get_dist_info, init_dist, master_only
from .hooks import (CheckpointHook, ClosureHook, DistSamplerSeedHook, Hook,
IterTimerHook, LoggerHook, LrUpdaterHook, OptimizerHook,
TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook)
from .hooks import (CheckpointHook, ClosureHook, DistSamplerSeedHook, HOOKS,
Hook, IterTimerHook, LoggerHook, LrUpdaterHook,
OptimizerHook, PaviLoggerHook, TensorboardLoggerHook,
TextLoggerHook, WandbLoggerHook)
from .log_buffer import LogBuffer
from .parallel_test import parallel_test
from .priority import Priority, get_priority
......@@ -12,11 +13,11 @@ from .runner import Runner
from .utils import get_host_info, get_time_str, obj_from_dict
__all__ = [
'Runner', 'LogBuffer', 'Hook', 'CheckpointHook', 'ClosureHook',
'Runner', 'LogBuffer', 'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook',
'LrUpdaterHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook',
'LoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', 'WandbLoggerHook',
'_load_checkpoint', 'load_state_dict', 'load_checkpoint', 'weights_to_cpu',
'save_checkpoint', 'parallel_test', 'Priority', 'get_priority',
'get_host_info', 'get_time_str', 'obj_from_dict', 'init_dist',
'get_dist_info', 'master_only'
'LoggerHook', 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
'WandbLoggerHook', '_load_checkpoint', 'load_state_dict',
'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'parallel_test',
'Priority', 'get_priority', 'get_host_info', 'get_time_str',
'obj_from_dict', 'init_dist', 'get_dist_info', 'master_only'
]
......@@ -3,8 +3,8 @@ from .checkpoint import CheckpointHook
from .closure import ClosureHook
from .hook import HOOKS, Hook
from .iter_timer import IterTimerHook
from .logger import (LoggerHook, TensorboardLoggerHook, TextLoggerHook,
WandbLoggerHook)
from .logger import (LoggerHook, PaviLoggerHook, TensorboardLoggerHook,
TextLoggerHook, WandbLoggerHook)
from .lr_updater import LrUpdaterHook
from .memory import EmptyCacheHook
from .optimizer import OptimizerHook
......@@ -13,5 +13,6 @@ from .sampler_seed import DistSamplerSeedHook
__all__ = [
'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook',
'LoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', 'WandbLoggerHook'
'LoggerHook', 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
'WandbLoggerHook'
]
# Copyright (c) Open-MMLab. All rights reserved.
from .base import LoggerHook
from .pavi import PaviLoggerHook
from .tensorboard import TensorboardLoggerHook
from .text import TextLoggerHook
from .wandb import WandbLoggerHook
__all__ = [
'LoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', 'WandbLoggerHook'
'LoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook', 'TextLoggerHook',
'WandbLoggerHook'
]
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
from mmcv.runner import master_only
from ..hook import HOOKS
from .base import LoggerHook
@HOOKS.register_module
class PaviLoggerHook(LoggerHook):
def __init__(self,
init_kwargs=None,
add_graph=False,
add_last_ckpt=False,
interval=10,
ignore_last=True,
reset_flag=True):
super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag)
self.init_kwargs = init_kwargs
self.add_graph = add_graph
self.add_last_ckpt = add_last_ckpt
@master_only
def before_run(self, runner):
try:
from pavi import SummaryWriter
except ImportError:
raise ImportError('Please run "pip install pavi" to install pavi.')
self.run_name = runner.work_dir.split('/')[-1]
if not self.init_kwargs:
self.init_kwargs = dict()
self.init_kwargs['task'] = self.run_name
self.init_kwargs['model'] = runner._model_name
self.writer = SummaryWriter(**self.init_kwargs)
if self.add_graph:
self.writer.add_graph(runner.model)
@master_only
def after_run(self, runner):
if self.add_last_ckpt:
ckpt_path = osp.join(runner.work_dir, 'latest.pth')
self.writer.add_snapshot_file(
tag=self.run_name,
snapshot_file_path=ckpt_path,
iteration=runner.iter)
@master_only
def log(self, runner):
tags = {}
for tag, val in runner.log_buffer.output.items():
if tag in ['time', 'data_time']:
continue
tags[tag] = val
if tags:
self.writer.add_scalars(runner.mode, tags, runner.iter)
......@@ -10,13 +10,14 @@ from .base import LoggerHook
class WandbLoggerHook(LoggerHook):
def __init__(self,
log_dir=None,
init_kwargs=None,
interval=10,
ignore_last=True,
reset_flag=True):
super(WandbLoggerHook, self).__init__(interval, ignore_last,
reset_flag)
self.import_wandb()
self.init_kwargs = init_kwargs
def import_wandb(self):
try:
......@@ -30,6 +31,9 @@ class WandbLoggerHook(LoggerHook):
def before_run(self, runner):
if self.wandb is None:
self.import_wandb()
if self.init_kwargs:
self.wandb.init(**self.init_kwargs)
else:
self.wandb.init()
@master_only
......@@ -39,7 +43,6 @@ class WandbLoggerHook(LoggerHook):
if var in ['time', 'data_time']:
continue
tag = '{}/{}'.format(var, runner.mode)
runner.log_buffer.output[var]
if isinstance(val, numbers.Number):
metrics[tag] = val
if metrics:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment