Commit 7d560402 authored by alexeib's avatar alexeib Committed by Myle Ott
Browse files

record end_of_epoch in checkpoint

parent 978c125a
......@@ -13,9 +13,9 @@ from unittest.mock import MagicMock, patch
import train
def mock_trainer(epoch, num_updates):
def mock_trainer(epoch, num_updates, end_of_epoch):
trainer = MagicMock()
trainer.load_checkpoint.return_value = {'epoch': epoch}
trainer.load_checkpoint.return_value = {'epoch': epoch, 'end_of_epoch': end_of_epoch}
trainer.get_num_updates.return_value = num_updates
return trainer
......@@ -38,21 +38,21 @@ class TestLoadCheckpoint(unittest.TestCase):
[p.start() for p in self.applied_patches]
def test_load_partial_checkpoint(self):
trainer = mock_trainer(2, 200)
trainer = mock_trainer(2, 200, False)
loader = mock_loader(150)
epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
self.assertEqual(epoch, 2)
self.assertEqual(next(ds), 50)
def test_load_full_checkpoint(self):
trainer = mock_trainer(2, 150)
trainer = mock_trainer(2, 300, True)
loader = mock_loader(150)
epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
self.assertEqual(epoch, 2)
self.assertEqual(epoch, 3)
self.assertEqual(next(iter(ds)), 0)
def test_load_no_checkpoint(self):
trainer = mock_trainer(0, 0)
trainer = mock_trainer(0, 0, False)
loader = mock_loader(150)
self.patches['os.path.isfile'].return_value = False
......
......@@ -280,6 +280,7 @@ def save_checkpoint(trainer, args, epoch, end_of_epoch, val_loss):
'epoch': epoch,
'val_loss': val_loss,
'wall_time': trainer.get_meter('wall').elapsed_time,
'end_of_epoch': end_of_epoch,
}
if end_of_epoch and not args.no_epoch_checkpoints:
......@@ -314,9 +315,10 @@ def load_checkpoint(args, trainer, train_dataloader):
extra_state = trainer.load_checkpoint(checkpoint_path)
if extra_state is not None:
epoch = extra_state['epoch']
end_of_epoch = extra_state.get('end_of_epoch', True)
trainer_updates = trainer.get_num_updates()
print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(checkpoint_path, epoch, trainer_updates))
print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
trainer.lr_step(epoch)
updates = 0
......@@ -324,14 +326,18 @@ def load_checkpoint(args, trainer, train_dataloader):
ds = next(train_dataloader)
updates += len(ds)
if ds is not None and updates > trainer_updates:
if not end_of_epoch and ds is not None and updates > trainer_updates:
completed_batches = len(ds) - (updates - trainer_updates)
assert completed_batches >= 0
ds = iter(ds)
print('| resuming from batch {}'.format(completed_batches + 1))
# consume completed batches
next(islice(ds, completed_batches, completed_batches), None)
else:
if not end_of_epoch:
print('| WARNING: checkpoint is not at end of epoch')
ds = next(train_dataloader)
epoch += 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment