"vscode:/vscode.git/clone" did not exist on "04188377c54aa9073e4c2496ddd9996da9fda629"
Commit 080489e9 authored by Kai Chen's avatar Kai Chen
Browse files

set default log dir for TensorboardLoggerHook

parent 99f53d2a
......@@ -14,7 +14,7 @@ lr_config = dict(policy='step', step=2)
# runtime settings
work_dir = './demo'
gpus = range(2)
dist_params = dict(backend='gloo') # gloo is much slower than nccl
dist_params = dict(backend='nccl')
data_workers = 2 # data workers per gpu
checkpoint_config = dict(interval=1) # save checkpoint at every epoch
workflow = [('train', 1), ('val', 1)]
......@@ -28,5 +28,5 @@ log_config = dict(
interval=50, # log at every 50 iterations
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log'),
# dict(type='TensorboardLoggerHook'),
])
import os.path as osp
from .base import LoggerHook
from ...utils import master_only
class TensorboardLoggerHook(LoggerHook):
def __init__(self, log_dir, interval=10, ignore_last=True,
def __init__(self,
log_dir=None,
interval=10,
ignore_last=True,
reset_flag=True):
super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
reset_flag)
......@@ -18,6 +23,8 @@ class TensorboardLoggerHook(LoggerHook):
raise ImportError('Please install tensorflow and tensorboardX '
'to use TensorboardLoggerHook.')
else:
if self.log_dir is None:
self.log_dir = osp.join(runner.work_dir, 'tf_logs')
self.writer = SummaryWriter(self.log_dir)
@master_only
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment