Unverified Commit 8b4417c1 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Fix the time estimation when resuming from a checkpoint (#37)

* fix the time estimation when resuming from a checkpoint

* fix the time estimation when resuming from a checkpoint
parent 1d6e91b1
...@@ -9,6 +9,10 @@ class TextLoggerHook(LoggerHook): ...@@ -9,6 +9,10 @@ class TextLoggerHook(LoggerHook):
super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag) super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag)
self.time_sec_tot = 0 self.time_sec_tot = 0
def before_run(self, runner):
super(TextLoggerHook, self).before_run(runner)
self.start_iter = runner.iter
def log(self, runner): def log(self, runner):
if runner.mode == 'train': if runner.mode == 'train':
lr_str = ', '.join( lr_str = ', '.join(
...@@ -20,9 +24,10 @@ class TextLoggerHook(LoggerHook): ...@@ -20,9 +24,10 @@ class TextLoggerHook(LoggerHook):
log_str = 'Epoch({}) [{}][{}]\t'.format(runner.mode, runner.epoch, log_str = 'Epoch({}) [{}][{}]\t'.format(runner.mode, runner.epoch,
runner.inner_iter + 1) runner.inner_iter + 1)
if 'time' in runner.log_buffer.output: if 'time' in runner.log_buffer.output:
self.time_sec_tot += (runner.log_buffer.output['time'] * self.time_sec_tot += (
self.interval) runner.log_buffer.output['time'] * self.interval)
time_sec_avg = self.time_sec_tot / (runner.iter + 1) time_sec_avg = self.time_sec_tot / (
runner.iter - self.start_iter + 1)
eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1) eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
eta_str = str(datetime.timedelta(seconds=int(eta_sec))) eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
log_str += 'eta: {}, '.format(eta_str) log_str += 'eta: {}, '.format(eta_str)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment