"src/vscode:/vscode.git/clone" did not exist on "b56f102765707ce34c7016409a8b901942edb977"
Commit 7d560402 authored by alexeib's avatar alexeib Committed by Myle Ott
Browse files

record end_of_epoch in checkpoint

parent 978c125a
...@@ -13,9 +13,9 @@ from unittest.mock import MagicMock, patch ...@@ -13,9 +13,9 @@ from unittest.mock import MagicMock, patch
import train import train
def mock_trainer(epoch, num_updates): def mock_trainer(epoch, num_updates, end_of_epoch):
trainer = MagicMock() trainer = MagicMock()
trainer.load_checkpoint.return_value = {'epoch': epoch} trainer.load_checkpoint.return_value = {'epoch': epoch, 'end_of_epoch': end_of_epoch}
trainer.get_num_updates.return_value = num_updates trainer.get_num_updates.return_value = num_updates
return trainer return trainer
...@@ -38,21 +38,21 @@ class TestLoadCheckpoint(unittest.TestCase): ...@@ -38,21 +38,21 @@ class TestLoadCheckpoint(unittest.TestCase):
[p.start() for p in self.applied_patches] [p.start() for p in self.applied_patches]
def test_load_partial_checkpoint(self): def test_load_partial_checkpoint(self):
trainer = mock_trainer(2, 200) trainer = mock_trainer(2, 200, False)
loader = mock_loader(150) loader = mock_loader(150)
epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader) epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
self.assertEqual(epoch, 2) self.assertEqual(epoch, 2)
self.assertEqual(next(ds), 50) self.assertEqual(next(ds), 50)
def test_load_full_checkpoint(self): def test_load_full_checkpoint(self):
trainer = mock_trainer(2, 150) trainer = mock_trainer(2, 300, True)
loader = mock_loader(150) loader = mock_loader(150)
epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader) epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
self.assertEqual(epoch, 2) self.assertEqual(epoch, 3)
self.assertEqual(next(iter(ds)), 0) self.assertEqual(next(iter(ds)), 0)
def test_load_no_checkpoint(self): def test_load_no_checkpoint(self):
trainer = mock_trainer(0, 0) trainer = mock_trainer(0, 0, False)
loader = mock_loader(150) loader = mock_loader(150)
self.patches['os.path.isfile'].return_value = False self.patches['os.path.isfile'].return_value = False
......
...@@ -280,6 +280,7 @@ def save_checkpoint(trainer, args, epoch, end_of_epoch, val_loss): ...@@ -280,6 +280,7 @@ def save_checkpoint(trainer, args, epoch, end_of_epoch, val_loss):
'epoch': epoch, 'epoch': epoch,
'val_loss': val_loss, 'val_loss': val_loss,
'wall_time': trainer.get_meter('wall').elapsed_time, 'wall_time': trainer.get_meter('wall').elapsed_time,
'end_of_epoch': end_of_epoch,
} }
if end_of_epoch and not args.no_epoch_checkpoints: if end_of_epoch and not args.no_epoch_checkpoints:
...@@ -314,9 +315,10 @@ def load_checkpoint(args, trainer, train_dataloader): ...@@ -314,9 +315,10 @@ def load_checkpoint(args, trainer, train_dataloader):
extra_state = trainer.load_checkpoint(checkpoint_path) extra_state = trainer.load_checkpoint(checkpoint_path)
if extra_state is not None: if extra_state is not None:
epoch = extra_state['epoch'] epoch = extra_state['epoch']
end_of_epoch = extra_state.get('end_of_epoch', True)
trainer_updates = trainer.get_num_updates() trainer_updates = trainer.get_num_updates()
print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(checkpoint_path, epoch, trainer_updates)) print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
trainer.lr_step(epoch) trainer.lr_step(epoch)
updates = 0 updates = 0
...@@ -324,14 +326,18 @@ def load_checkpoint(args, trainer, train_dataloader): ...@@ -324,14 +326,18 @@ def load_checkpoint(args, trainer, train_dataloader):
ds = next(train_dataloader) ds = next(train_dataloader)
updates += len(ds) updates += len(ds)
if ds is not None and updates > trainer_updates: if not end_of_epoch and ds is not None and updates > trainer_updates:
completed_batches = len(ds) - (updates - trainer_updates) completed_batches = len(ds) - (updates - trainer_updates)
assert completed_batches >= 0 assert completed_batches >= 0
ds = iter(ds) ds = iter(ds)
print('| resuming from batch {}'.format(completed_batches + 1))
# consume completed batches # consume completed batches
next(islice(ds, completed_batches, completed_batches), None) next(islice(ds, completed_batches, completed_batches), None)
else: else:
if not end_of_epoch:
print('| WARNING: checkpoint is not at end of epoch')
ds = next(train_dataloader) ds = next(train_dataloader)
epoch += 1 epoch += 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment