Unverified Commit eacaf475 authored by Wang Xinjiang's avatar Wang Xinjiang Committed by GitHub
Browse files

fix some pavi logger hooks (#481)

* fix some pavi logger hooks

* fix unittest

* fix small bugs

* small change

* fix unittest

* Add EpochBasedRunner conditions

* Add session text

* fix small bug

* fetch runner mode from log buffer

* Add max_iter to pavi session text

* change yaml.dump to yamp.dump(yaml.load(mmcv.dump))

* Directly use by_epoch

* fix unittest

* add comments

* Use runner.epoch + 1 in pavi log

* fix runner.epoch issue for runner.mode=='val'

* fix runner.epoch issue for runner.mode=='val'

* Use abspath instead of realpath

* Add meta dump unittest

* small change

* Add comments
parent c8e85b28
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
import json
import numbers import numbers
import os
import os.path as osp import os.path as osp
import numpy as np import numpy as np
import torch import torch
import yaml
import mmcv
from ...dist_utils import master_only from ...dist_utils import master_only
from ..hook import HOOKS from ..hook import HOOKS
from .base import LoggerHook from .base import LoggerHook
...@@ -61,7 +65,16 @@ class PaviLoggerHook(LoggerHook): ...@@ -61,7 +65,16 @@ class PaviLoggerHook(LoggerHook):
self.init_kwargs = dict() self.init_kwargs = dict()
self.init_kwargs['task'] = self.run_name self.init_kwargs['task'] = self.run_name
self.init_kwargs['model'] = runner._model_name self.init_kwargs['model'] = runner._model_name
if runner.meta is not None and 'config_dict' in runner.meta:
config_dict = runner.meta['config_dict'].copy()
# 'max_.*iter' is parsed in pavi sdk as the maximum iterations
# to properly set up the progress bar.
config_dict.setdefault('max_iter', runner.max_iters)
# non-serializable values are first converted in mmcv.dump to json
config_dict = json.loads(
mmcv.dump(config_dict, file_format='json'))
session_text = yaml.dump(config_dict)
self.init_kwargs['session_text'] = session_text
self.writer = SummaryWriter(**self.init_kwargs) self.writer = SummaryWriter(**self.init_kwargs)
if self.add_graph: if self.add_graph:
...@@ -90,13 +103,27 @@ class PaviLoggerHook(LoggerHook): ...@@ -90,13 +103,27 @@ class PaviLoggerHook(LoggerHook):
tags['momentum'] = momentums[0] tags['momentum'] = momentums[0]
if tags: if tags:
self.writer.add_scalars(runner.mode, tags, runner.iter) if runner.mode == 'val':
mode = runner.mode
# runner.epoch += 1 has been done before val workflow
epoch = runner.epoch
else:
mode = 'train' if 'time' in runner.log_buffer.output else 'val'
epoch = runner.epoch + 1
if mode == 'val' and self.by_epoch:
self.writer.add_scalars(mode, tags, epoch)
else:
self.writer.add_scalars(mode, tags, runner.iter)
@master_only @master_only
def after_run(self, runner): def after_run(self, runner):
if self.add_last_ckpt: if self.add_last_ckpt:
ckpt_path = osp.join(runner.work_dir, 'latest.pth') ckpt_path = osp.join(runner.work_dir, 'latest.pth')
self.writer.add_snapshot_file( if osp.isfile(ckpt_path):
ckpt_path = osp.join(runner.work_dir, os.readlink(ckpt_path))
# runner.epoch += 1 has been done before `after_run`.
iteration = runner.epoch if self.by_epoch else runner.iter
return self.writer.add_snapshot_file(
tag=self.run_name, tag=self.run_name,
snapshot_file_path=ckpt_path, snapshot_file_path=ckpt_path,
iteration=runner.iter) iteration=iteration)
...@@ -98,6 +98,7 @@ def test_pavi_hook(): ...@@ -98,6 +98,7 @@ def test_pavi_hook():
loader = DataLoader(torch.ones((5, 2))) loader = DataLoader(torch.ones((5, 2)))
runner = _build_demo_runner() runner = _build_demo_runner()
runner.meta = dict(config_dict=dict(lr=0.02, gpu_ids=range(1)))
hook = PaviLoggerHook(add_graph=False, add_last_ckpt=True) hook = PaviLoggerHook(add_graph=False, add_last_ckpt=True)
runner.register_hook(hook) runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)], 1) runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
...@@ -107,11 +108,11 @@ def test_pavi_hook(): ...@@ -107,11 +108,11 @@ def test_pavi_hook():
hook.writer.add_scalars.assert_called_with('val', { hook.writer.add_scalars.assert_called_with('val', {
'learning_rate': 0.02, 'learning_rate': 0.02,
'momentum': 0.95 'momentum': 0.95
}, 5) }, 1)
hook.writer.add_snapshot_file.assert_called_with( hook.writer.add_snapshot_file.assert_called_with(
tag=runner.work_dir.split('/')[-1], tag=runner.work_dir.split('/')[-1],
snapshot_file_path=osp.join(runner.work_dir, 'latest.pth'), snapshot_file_path=osp.join(runner.work_dir, 'epoch_1.pth'),
iteration=5) iteration=1)
def test_sync_buffers_hook(): def test_sync_buffers_hook():
...@@ -378,6 +379,6 @@ def _build_demo_runner(): ...@@ -378,6 +379,6 @@ def _build_demo_runner():
work_dir=tmp_dir, work_dir=tmp_dir,
optimizer=optimizer, optimizer=optimizer,
logger=logging.getLogger()) logger=logging.getLogger())
runner.register_checkpoint_hook(dict(interval=1))
runner.register_logger_hooks(log_config) runner.register_logger_hooks(log_config)
return runner return runner
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment