Commit 736fbee2 authored by Myle Ott's avatar Myle Ott
Browse files

Suppress stdout in test_train

parent 13aa36cf
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
import contextlib
from io import StringIO
import unittest import unittest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
...@@ -37,27 +39,30 @@ class TestLoadCheckpoint(unittest.TestCase): ...@@ -37,27 +39,30 @@ class TestLoadCheckpoint(unittest.TestCase):
[p.start() for p in self.applied_patches] [p.start() for p in self.applied_patches]
def test_load_partial_checkpoint(self): def test_load_partial_checkpoint(self):
trainer = mock_trainer(2, 200, False) with contextlib.redirect_stdout(StringIO()):
loader = mock_loader(150) trainer = mock_trainer(2, 200, False)
epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader) loader = mock_loader(150)
self.assertEqual(epoch, 2) epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
self.assertEqual(next(ds), 50) self.assertEqual(epoch, 2)
self.assertEqual(next(ds), 50)
def test_load_full_checkpoint(self): def test_load_full_checkpoint(self):
trainer = mock_trainer(2, 300, True) with contextlib.redirect_stdout(StringIO()):
loader = mock_loader(150) trainer = mock_trainer(2, 300, True)
epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader) loader = mock_loader(150)
self.assertEqual(epoch, 3) epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
self.assertEqual(next(iter(ds)), 0) self.assertEqual(epoch, 3)
self.assertEqual(next(iter(ds)), 0)
def test_load_no_checkpoint(self): def test_load_no_checkpoint(self):
trainer = mock_trainer(0, 0, False) with contextlib.redirect_stdout(StringIO()):
loader = mock_loader(150) trainer = mock_trainer(0, 0, False)
self.patches['os.path.isfile'].return_value = False loader = mock_loader(150)
self.patches['os.path.isfile'].return_value = False
epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
self.assertEqual(epoch, 1) epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
self.assertEqual(next(iter(ds)), 0) self.assertEqual(epoch, 1)
self.assertEqual(next(iter(ds)), 0)
def tearDown(self): def tearDown(self):
patch.stopall() patch.stopall()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment