Commit 0daba38e authored by Myle Ott
Browse files

Save and restore wall time in checkpoints

parent dc40ac58
...@@ -28,10 +28,11 @@ class AverageMeter(object): ...@@ -28,10 +28,11 @@ class AverageMeter(object):
class TimeMeter(object): class TimeMeter(object):
"""Computes the average occurrence of some event per second""" """Computes the average occurrence of some event per second"""
def __init__(self): def __init__(self, init=0):
self.reset() self.reset(init)
def reset(self): def reset(self, init=0):
self.init = init
self.start = time.time() self.start = time.time()
self.n = 0 self.n = 0
...@@ -40,12 +41,11 @@ class TimeMeter(object): ...@@ -40,12 +41,11 @@ class TimeMeter(object):
@property @property
def avg(self): def avg(self):
delta = time.time() - self.start return self.n / self.elapsed_time
return self.n / delta
@property @property
def elapsed_time(self): def elapsed_time(self):
return time.time() - self.start return self.init + (time.time() - self.start)
class StopwatchMeter(object): class StopwatchMeter(object):
......
...@@ -273,6 +273,7 @@ def save_checkpoint(trainer, args, epoch, val_loss=None): ...@@ -273,6 +273,7 @@ def save_checkpoint(trainer, args, epoch, val_loss=None):
extra_state = { extra_state = {
'epoch': epoch, 'epoch': epoch,
'val_loss': val_loss, 'val_loss': val_loss,
'wall_time': trainer.get_meter('wall').elapsed_time,
} }
if not args.no_epoch_checkpoints: if not args.no_epoch_checkpoints:
...@@ -302,6 +303,7 @@ def load_checkpoint(args, trainer, train_dataloader): ...@@ -302,6 +303,7 @@ def load_checkpoint(args, trainer, train_dataloader):
for i in range(epoch): for i in range(epoch):
_ = next(train_dataloader) _ = next(train_dataloader)
epoch += 1 epoch += 1
trainer.get_meter('wall').reset(init=extra_state.get('wall_time', 0))
return epoch return epoch
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment