"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "971c3e45b96bc5aa5868c45cd40e4f3c3d90d126"
Commit 080489e9 authored by Kai Chen's avatar Kai Chen
Browse files

set default log dir for TensorboardLoggerHook

parent 99f53d2a
...@@ -14,7 +14,7 @@ lr_config = dict(policy='step', step=2) ...@@ -14,7 +14,7 @@ lr_config = dict(policy='step', step=2)
# runtime settings # runtime settings
work_dir = './demo' work_dir = './demo'
gpus = range(2) gpus = range(2)
dist_params = dict(backend='gloo') # gloo is much slower than nccl dist_params = dict(backend='nccl')
data_workers = 2 # data workers per gpu data_workers = 2 # data workers per gpu
checkpoint_config = dict(interval=1) # save checkpoint at every epoch checkpoint_config = dict(interval=1) # save checkpoint at every epoch
workflow = [('train', 1), ('val', 1)] workflow = [('train', 1), ('val', 1)]
...@@ -28,5 +28,5 @@ log_config = dict( ...@@ -28,5 +28,5 @@ log_config = dict(
interval=50, # log at every 50 iterations interval=50, # log at every 50 iterations
hooks=[ hooks=[
dict(type='TextLoggerHook'), dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log'), # dict(type='TensorboardLoggerHook'),
]) ])
import os.path as osp
from .base import LoggerHook from .base import LoggerHook
from ...utils import master_only from ...utils import master_only
class TensorboardLoggerHook(LoggerHook): class TensorboardLoggerHook(LoggerHook):
def __init__(self, log_dir, interval=10, ignore_last=True, def __init__(self,
log_dir=None,
interval=10,
ignore_last=True,
reset_flag=True): reset_flag=True):
super(TensorboardLoggerHook, self).__init__(interval, ignore_last, super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
reset_flag) reset_flag)
...@@ -18,6 +23,8 @@ class TensorboardLoggerHook(LoggerHook): ...@@ -18,6 +23,8 @@ class TensorboardLoggerHook(LoggerHook):
raise ImportError('Please install tensorflow and tensorboardX ' raise ImportError('Please install tensorflow and tensorboardX '
'to use TensorboardLoggerHook.') 'to use TensorboardLoggerHook.')
else: else:
if self.log_dir is None:
self.log_dir = osp.join(runner.work_dir, 'tf_logs')
self.writer = SummaryWriter(self.log_dir) self.writer = SummaryWriter(self.log_dir)
@master_only @master_only
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment