Commit 736fbee2 authored by Myle Ott's avatar Myle Ott
Browse files

Suppress stdout in test_train

parent 13aa36cf
......@@ -5,6 +5,8 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import contextlib
from io import StringIO
import unittest
from unittest.mock import MagicMock, patch
......@@ -37,6 +39,7 @@ class TestLoadCheckpoint(unittest.TestCase):
[p.start() for p in self.applied_patches]
def test_load_partial_checkpoint(self):
with contextlib.redirect_stdout(StringIO()):
trainer = mock_trainer(2, 200, False)
loader = mock_loader(150)
epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
......@@ -44,6 +47,7 @@ class TestLoadCheckpoint(unittest.TestCase):
self.assertEqual(next(ds), 50)
def test_load_full_checkpoint(self):
with contextlib.redirect_stdout(StringIO()):
trainer = mock_trainer(2, 300, True)
loader = mock_loader(150)
epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
......@@ -51,6 +55,7 @@ class TestLoadCheckpoint(unittest.TestCase):
self.assertEqual(next(iter(ds)), 0)
def test_load_no_checkpoint(self):
with contextlib.redirect_stdout(StringIO()):
trainer = mock_trainer(0, 0, False)
loader = mock_loader(150)
self.patches['os.path.isfile'].return_value = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment