from .base import LoggerHook
from ...utils import master_only


class TensorboardLoggerHook(LoggerHook):
    """Write the entries of ``runner.log_buffer.output`` to TensorBoard.

    Each entry (except timing entries) is written under the tag
    ``'<name>/<runner.mode>'`` at step ``runner.iter``. ``tensorboardX`` is
    used when available; otherwise PyTorch's ``torch.utils.tensorboard``
    writer is used as a fallback.

    Args:
        log_dir (str): Directory passed to ``SummaryWriter``.
        interval (int): Forwarded unchanged to ``LoggerHook``.
        ignore_last (bool): Forwarded unchanged to ``LoggerHook``.
        reset_flag (bool): Forwarded unchanged to ``LoggerHook``.
    """

    # Buffer entries that are never written to TensorBoard.
    _SKIPPED_VARS = ('time', 'data_time')

    def __init__(self, log_dir, interval=10, ignore_last=True,
                 reset_flag=True):
        super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
                                                    reset_flag)
        self.log_dir = log_dir
        # Created in before_run so the optional dependency is only imported
        # once a run actually starts.
        self.writer = None

    @master_only
    def before_run(self, runner):
        """Create the ``SummaryWriter``, preferring tensorboardX.

        Raises:
            ImportError: if neither tensorboardX nor
                ``torch.utils.tensorboard`` can be imported.
        """
        try:
            from tensorboardX import SummaryWriter
        except ImportError:
            try:
                # Bundled with PyTorch >= 1.1; needs the ``tensorboard``
                # package installed.
                from torch.utils.tensorboard import SummaryWriter
            except ImportError:
                raise ImportError(
                    'Please install tensorboardX, or tensorboard together '
                    'with PyTorch>=1.1, to use TensorboardLoggerHook.')
        self.writer = SummaryWriter(self.log_dir)

    @master_only
    def log(self, runner):
        """Write every non-timing buffer entry at step ``runner.iter``."""
        for var, value in runner.log_buffer.output.items():
            if var in self._SKIPPED_VARS:
                continue
            tag = '{}/{}'.format(var, runner.mode)
            if isinstance(value, str):
                # add_scalar cannot convert arbitrary text to a number.
                self.writer.add_text(tag, value, runner.iter)
            else:
                self.writer.add_scalar(tag, value, runner.iter)

    @master_only
    def after_run(self, runner):
        """Close the writer if one was created."""
        if self.writer is not None:
            self.writer.close()
            self.writer = None
