Commit 0b5166db authored by alexeib's avatar alexeib Committed by Myle Ott
Browse files

fix tests

parent 2dc074d8
......@@ -52,6 +52,8 @@ def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoc
class TestLoadCheckpoint(unittest.TestCase):
def setUp(self):
self.args_mock = MagicMock()
self.args_mock.optimizer_overrides = '{}'
self.patches = {
'os.makedirs': MagicMock(),
'os.path.join': MagicMock(),
......@@ -60,11 +62,12 @@ class TestLoadCheckpoint(unittest.TestCase):
self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
[p.start() for p in self.applied_patches]
def test_load_partial_checkpoint(self):
with contextlib.redirect_stdout(StringIO()):
trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
train.load_checkpoint(MagicMock(), trainer, epoch_itr)
train.load_checkpoint(self.args_mock, trainer, epoch_itr)
self.assertEqual(epoch_itr.epoch, 2)
self.assertEqual(epoch_itr.iterations_in_epoch, 50)
......@@ -79,7 +82,7 @@ class TestLoadCheckpoint(unittest.TestCase):
with contextlib.redirect_stdout(StringIO()):
trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
train.load_checkpoint(MagicMock(), trainer, epoch_itr)
train.load_checkpoint(self.args_mock, trainer, epoch_itr)
itr = epoch_itr.next_epoch_itr(shuffle=False)
self.assertEqual(epoch_itr.epoch, 3)
......@@ -91,7 +94,7 @@ class TestLoadCheckpoint(unittest.TestCase):
trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0)
self.patches['os.path.isfile'].return_value = False
train.load_checkpoint(MagicMock(), trainer, epoch_itr)
train.load_checkpoint(self.args_mock, trainer, epoch_itr)
itr = epoch_itr.next_epoch_itr(shuffle=False)
self.assertEqual(epoch_itr.epoch, 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment