Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
736fbee2
Commit
736fbee2
authored
Jun 04, 2018
by
Myle Ott
Browse files
Suppress stdout in test_train
parent
13aa36cf
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
17 deletions
+22
-17
tests/test_train.py
tests/test_train.py
+22
-17
No files found.
tests/test_train.py
View file @
736fbee2
...
@@ -5,6 +5,8 @@
...
@@ -5,6 +5,8 @@
# the root directory of this source tree. An additional grant of patent rights
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
# can be found in the PATENTS file in the same directory.
import
contextlib
from
io
import
StringIO
import
unittest
import
unittest
from
unittest.mock
import
MagicMock
,
patch
from
unittest.mock
import
MagicMock
,
patch
...
@@ -37,27 +39,30 @@ class TestLoadCheckpoint(unittest.TestCase):
...
@@ -37,27 +39,30 @@ class TestLoadCheckpoint(unittest.TestCase):
[
p
.
start
()
for
p
in
self
.
applied_patches
]
[
p
.
start
()
for
p
in
self
.
applied_patches
]
def
test_load_partial_checkpoint
(
self
):
def
test_load_partial_checkpoint
(
self
):
trainer
=
mock_trainer
(
2
,
200
,
False
)
with
contextlib
.
redirect_stdout
(
StringIO
()):
loader
=
mock_loader
(
150
)
trainer
=
mock_trainer
(
2
,
200
,
False
)
epoch
,
ds
=
train
.
load_checkpoint
(
MagicMock
(),
trainer
,
loader
)
loader
=
mock_loader
(
150
)
self
.
assertEqual
(
epoch
,
2
)
epoch
,
ds
=
train
.
load_checkpoint
(
MagicMock
(),
trainer
,
loader
)
self
.
assertEqual
(
next
(
ds
),
50
)
self
.
assertEqual
(
epoch
,
2
)
self
.
assertEqual
(
next
(
ds
),
50
)
def
test_load_full_checkpoint
(
self
):
def
test_load_full_checkpoint
(
self
):
trainer
=
mock_trainer
(
2
,
300
,
True
)
with
contextlib
.
redirect_stdout
(
StringIO
()):
loader
=
mock_loader
(
150
)
trainer
=
mock_trainer
(
2
,
300
,
True
)
epoch
,
ds
=
train
.
load_checkpoint
(
MagicMock
(),
trainer
,
loader
)
loader
=
mock_loader
(
150
)
self
.
assertEqual
(
epoch
,
3
)
epoch
,
ds
=
train
.
load_checkpoint
(
MagicMock
(),
trainer
,
loader
)
self
.
assertEqual
(
next
(
iter
(
ds
)),
0
)
self
.
assertEqual
(
epoch
,
3
)
self
.
assertEqual
(
next
(
iter
(
ds
)),
0
)
def
test_load_no_checkpoint
(
self
):
def
test_load_no_checkpoint
(
self
):
trainer
=
mock_trainer
(
0
,
0
,
False
)
with
contextlib
.
redirect_stdout
(
StringIO
()):
loader
=
mock_loader
(
150
)
trainer
=
mock_trainer
(
0
,
0
,
False
)
self
.
patches
[
'os.path.isfile'
].
return_value
=
False
loader
=
mock_loader
(
150
)
self
.
patches
[
'os.path.isfile'
].
return_value
=
False
epoch
,
ds
=
train
.
load_checkpoint
(
MagicMock
(),
trainer
,
loader
)
self
.
assertEqual
(
epoch
,
1
)
epoch
,
ds
=
train
.
load_checkpoint
(
MagicMock
(),
trainer
,
loader
)
self
.
assertEqual
(
next
(
iter
(
ds
)),
0
)
self
.
assertEqual
(
epoch
,
1
)
self
.
assertEqual
(
next
(
iter
(
ds
)),
0
)
def
tearDown
(
self
):
def
tearDown
(
self
):
patch
.
stopall
()
patch
.
stopall
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment