"doc/git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "0e4b7a3929e12d1645e3e177148d15cd4cdec793"
Commit 8266d877 authored by Cao Yuhang, committed by Kai Chen

add json log support (#55)

* add json log support

* add some comment, minor fix

* rename variable vv to item

* mv round float to a function
parent 506455af
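Note: this commit makes TextLoggerHook append one JSON object per log step to <work_dir>/<timestamp>.log.json, alongside the existing plain-text log. A minimal sketch of consuming such a file follows; the path and the 'loss' key are illustrative assumptions, since the actual keys depend on what the model puts in the log buffer.

import json

def load_json_log(path):
    # One JSON dict per line, as written by TextLoggerHook._dump_log below.
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:
                yield json.loads(line)

# Hypothetical usage; 'loss' is whatever the model reports, not a fixed key.
for record in load_json_log('work_dir/20190329_182633.log.json'):
    if record['mode'] == 'train':
        print(record['epoch'], record['iter'], record.get('loss'))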
@@ -48,6 +48,8 @@ class LoggerHook(Hook):
     def after_train_epoch(self, runner):
         if runner.log_buffer.ready:
             self.log(runner)
+            if self.reset_flag:
+                runner.log_buffer.clear_output()

     def after_val_epoch(self, runner):
         runner.log_buffer.average()
...
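Context for the added reset_flag guard: the buffered outputs are cleared only by a hook whose reset_flag is set, so every registered logger sees the same averaged values before they are discarded; typically only the last logger in the chain clears them. A sketch of a custom logger on the same base class (the import path is assumed from this tree's layout):

from mmcv.runner.hooks.logger import LoggerHook

class PrintLoggerHook(LoggerHook):
    # Minimal logger: LoggerHook drives the interval/averaging logic and
    # calls log(); reset_flag stays False here so a later logger can still
    # read the same buffered output.
    def log(self, runner):
        print(runner.mode, runner.epoch + 1, dict(runner.log_buffer.output))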
 import datetime
+import os.path as osp
+from collections import OrderedDict

 import torch
 import torch.distributed as dist

+import mmcv
 from .base import LoggerHook
@@ -15,6 +18,8 @@ class TextLoggerHook(LoggerHook):
     def before_run(self, runner):
         super(TextLoggerHook, self).before_run(runner)
         self.start_iter = runner.iter
+        self.json_log_path = osp.join(runner.work_dir,
+                                      '{}.log.json'.format(runner.timestamp))

     def _get_max_memory(self, runner):
         mem = torch.cuda.max_memory_allocated()
@@ -23,40 +28,79 @@ class TextLoggerHook(LoggerHook):
                               device=torch.device('cuda'))
         if runner.world_size > 1:
             dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
-        return mem_mb
+        return mem_mb.item()

-    def log(self, runner):
+    def _log_info(self, log_dict, runner):
         if runner.mode == 'train':
-            lr_str = ', '.join(
-                ['{:.5f}'.format(lr) for lr in runner.current_lr()])
+            lr_str = ', '.join(['{:.5f}'.format(lr) for lr in log_dict['lr']])
             log_str = 'Epoch [{}][{}/{}]\tlr: {}, '.format(
-                runner.epoch + 1, runner.inner_iter + 1,
-                len(runner.data_loader), lr_str)
+                log_dict['epoch'], log_dict['iter'], len(runner.data_loader),
+                lr_str)
+            if 'time' in log_dict.keys():
+                self.time_sec_tot += (log_dict['time'] * self.interval)
+                time_sec_avg = self.time_sec_tot / (
+                    runner.iter - self.start_iter + 1)
+                eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
+                eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+                log_str += 'eta: {}, '.format(eta_str)
+                log_str += ('time: {:.3f}, data_time: {:.3f}, '.format(
+                    log_dict['time'], log_dict['data_time']))
+                log_str += 'memory: {}, '.format(log_dict['memory'])
         else:
-            log_str = 'Epoch({}) [{}][{}]\t'.format(runner.mode, runner.epoch,
-                                                    runner.inner_iter + 1)
-        if 'time' in runner.log_buffer.output:
-            self.time_sec_tot += (
-                runner.log_buffer.output['time'] * self.interval)
-            time_sec_avg = self.time_sec_tot / (
-                runner.iter - self.start_iter + 1)
-            eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
-            eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
-            log_str += 'eta: {}, '.format(eta_str)
-            log_str += (
-                'time: {log[time]:.3f}, data_time: {log[data_time]:.3f}, '.
-                format(log=runner.log_buffer.output))
-        # statistic memory
-        # training mode if the output contains the key "time"
-        if 'time' in runner.log_buffer.output and torch.cuda.is_available():
-            mem_mb = self._get_max_memory(runner)
-            log_str += 'memory: {}, '.format(mem_mb.item())
+            log_str = 'Epoch({}) [{}][{}]\t'.format(
+                log_dict['mode'], log_dict['epoch'] - 1, log_dict['iter'])
         log_items = []
-        for name, val in runner.log_buffer.output.items():
-            if name in ['time', 'data_time']:
+        for name, val in log_dict.items():
+            # TODO: resolve this hack
+            # these items have been in log_str
+            if name in [
+                    'mode', 'Epoch', 'iter', 'lr', 'time', 'data_time',
+                    'memory', 'epoch'
+            ]:
                 continue
             if isinstance(val, float):
                 val = '{:.4f}'.format(val)
             log_items.append('{}: {}'.format(name, val))
         log_str += ', '.join(log_items)
         runner.logger.info(log_str)

+    def _dump_log(self, log_dict, runner):
+        # dump log in json format
+        json_log = OrderedDict()
+        for k, v in log_dict.items():
+            json_log[k] = self._round_float(v)
+        # only append log at last line
+        if runner.rank == 0:
+            with open(self.json_log_path, 'a+') as f:
+                mmcv.dump(json_log, f, file_format='json')
+                f.write('\n')
+
+    def _round_float(self, items):
+        if isinstance(items, list):
+            return [self._round_float(item) for item in items]
+        elif isinstance(items, float):
+            return round(items, 5)
+        else:
+            return items
+
+    def log(self, runner):
+        log_dict = OrderedDict()
+        # training mode if the output contains the key "time"
+        mode = 'train' if 'time' in runner.log_buffer.output else 'val'
+        log_dict['mode'] = mode
+        log_dict['epoch'] = runner.epoch + 1
+        log_dict['iter'] = runner.inner_iter + 1
+        log_dict['lr'] = [lr for lr in runner.current_lr()]
+        if mode == 'train':
+            log_dict['time'] = runner.log_buffer.output['time']
+            log_dict['data_time'] = runner.log_buffer.output['data_time']
+            # statistic memory
+            if torch.cuda.is_available():
+                log_dict['memory'] = self._get_max_memory(runner)
+        for name, val in runner.log_buffer.output.items():
+            if name in ['time', 'data_time']:
+                continue
+            log_dict[name] = val
+        self._log_info(log_dict, runner)
+        self._dump_log(log_dict, runner)
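The ETA reported by _log_info is simply the running average time per iteration, extrapolated over the remaining iterations. A worked example with invented numbers:

import datetime

# Invented values: time_sec_tot accumulates log_dict['time'] * interval.
time_sec_tot = 120.0   # seconds spent in the logged iterations so far
iters_done = 500       # runner.iter - self.start_iter + 1
max_iters = 10000
cur_iter = 500         # runner.iter

time_sec_avg = time_sec_tot / iters_done             # 0.24 s/iter
eta_sec = time_sec_avg * (max_iters - cur_iter - 1)  # ~2279.8 s
print(str(datetime.timedelta(seconds=int(eta_sec))))  # 0:37:59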
@@ -62,6 +62,7 @@ class Runner(object):
         self._model_name = self.model.__class__.__name__
         self._rank, self._world_size = get_dist_info()
+        self.timestamp = get_time_str()
         if logger is None:
             self.logger = self.init_logger(work_dir, log_level)
         else:
@@ -174,7 +175,7 @@ class Runner(object):
             format='%(asctime)s - %(levelname)s - %(message)s', level=level)
         logger = logging.getLogger(__name__)
         if log_dir and self.rank == 0:
-            filename = '{}.log'.format(get_time_str())
+            filename = '{}.log'.format(self.timestamp)
             log_file = osp.join(log_dir, filename)
             self._add_file_handler(logger, log_file, level=level)
         return logger
...
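With runner.timestamp set once in the constructor, the plain-text log and the JSON log share a common stem, which makes it easy to pair them. A small illustration (names follow from the code above; the work_dir and timestamp values are made up):

import os.path as osp

def log_paths(work_dir, timestamp):
    # '<timestamp>.log' comes from Runner.init_logger,
    # '<timestamp>.log.json' from TextLoggerHook.before_run.
    return (osp.join(work_dir, '{}.log'.format(timestamp)),
            osp.join(work_dir, '{}.log.json'.format(timestamp)))

print(log_paths('work_dirs/faster_rcnn', '20190329_182633'))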