Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
e775877f
Commit
e775877f
authored
Sep 09, 2018
by
Myle Ott
Browse files
Add unit test to verify reproducibility after reloading checkpoints
parent
83e08b6f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
91 additions
and
3 deletions
+91
-3
fairseq/optim/fp16_optimizer.py
fairseq/optim/fp16_optimizer.py
+1
-1
fairseq/options.py
fairseq/options.py
+2
-0
fairseq/trainer.py
fairseq/trainer.py
+4
-2
tests/test_reproducibility.py
tests/test_reproducibility.py
+84
-0
No files found.
fairseq/optim/fp16_optimizer.py
View file @
e775877f
...
...
@@ -42,7 +42,7 @@ class FP16Optimizer(optim.FairseqOptimizer):
self
.
fp32_optimizer
=
fp32_optimizer
self
.
fp32_params
=
fp32_params
self
.
scaler
=
DynamicLossScaler
(
init_scale
=
2.
**
7
,
init_scale
=
args
.
fp16_init_scale
,
scale_window
=
(
2
**
14
/
args
.
distributed_world_size
),
)
...
...
fairseq/options.py
View file @
e775877f
...
...
@@ -128,6 +128,8 @@ def get_parser(desc, default_task='translation'):
parser
.
add_argument
(
'--seed'
,
default
=
1
,
type
=
int
,
metavar
=
'N'
,
help
=
'pseudo random number generator seed'
)
parser
.
add_argument
(
'--fp16'
,
action
=
'store_true'
,
help
=
'use FP16'
)
parser
.
add_argument
(
'--fp16-init-scale'
,
default
=
2
**
7
,
type
=
int
,
help
=
'default FP16 loss scale'
)
# Task definitions can be found under fairseq/tasks/
parser
.
add_argument
(
...
...
fairseq/trainer.py
View file @
e775877f
...
...
@@ -250,7 +250,8 @@ class Trainer(object):
)
self
.
meters
[
'oom'
].
update
(
ooms
)
self
.
meters
[
'train_loss'
].
update
(
logging_output
.
get
(
'loss'
,
0
),
sample_size
)
self
.
meters
[
'train_nll_loss'
].
update
(
logging_output
.
get
(
'nll_loss'
,
0
),
ntokens
)
if
'nll_loss'
in
logging_output
:
self
.
meters
[
'train_nll_loss'
].
update
(
logging_output
.
get
(
'nll_loss'
,
0
),
ntokens
)
except
OverflowError
as
e
:
print
(
'| WARNING: overflow detected, '
+
str
(
e
))
self
.
zero_grad
()
...
...
@@ -301,7 +302,8 @@ class Trainer(object):
# update meters for validation
ntokens
=
logging_output
.
get
(
'ntokens'
,
0
)
self
.
meters
[
'valid_loss'
].
update
(
logging_output
.
get
(
'loss'
,
0
),
sample_size
)
self
.
meters
[
'valid_nll_loss'
].
update
(
logging_output
.
get
(
'nll_loss'
,
0
),
ntokens
)
if
'nll_loss'
in
logging_output
:
self
.
meters
[
'valid_nll_loss'
].
update
(
logging_output
.
get
(
'nll_loss'
,
0
),
ntokens
)
return
logging_output
...
...
tests/test_reproducibility.py
0 → 100644
View file @
e775877f
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import
contextlib
from
io
import
StringIO
import
json
import
os
import
tempfile
import
unittest
import
torch
from
fairseq
import
options
from
.
import
test_binaries
class TestReproducibility(unittest.TestCase):
    """Compare the logs of an uninterrupted training run with a resumed one.

    Each test trains a small ``fconv_iwslt_de_en`` model on dummy data, then
    renames ``checkpoint1.pt`` to ``checkpoint_last.pt``, trains again with
    the same flags, and asserts that selected logged statistics are equal
    after rounding to three decimal places.
    """

    def _test_reproducibility(self, name, extra_flags=None):
        """Run the two-pass training comparison inside a temporary directory.

        Args:
            name: passed to ``tempfile.TemporaryDirectory`` as its first
                argument, i.e. the *suffix* of the temporary directory name.
            extra_flags: optional list of extra command-line flags appended
                to the training flags of both runs. ``None`` means no extras.
        """
        appended_flags = [] if extra_flags is None else extra_flags

        def train_and_read_logs(workdir):
            # Capture everything the trainer prints so the JSON log lines can
            # be parsed afterwards. The flag list is rebuilt on every call so
            # both runs receive their own fresh list object.
            captured = StringIO()
            with contextlib.redirect_stdout(captured):
                test_binaries.train_translation_model(
                    workdir,
                    'fconv_iwslt_de_en',
                    [
                        '--dropout', '0.0',
                        '--log-format', 'json',
                        '--log-interval', '1',
                        '--max-epoch', '3',
                    ] + appended_flags,
                )
            output_lines = captured.getvalue().split('\n')
            # The fourth- and third-from-last lines are expected to hold the
            # final train and validation JSON entries, in that order.
            train_entry, valid_entry = map(json.loads, output_lines[-4:-2])
            return train_entry, valid_entry

        def rounded(value):
            # Logged values may be strings; compare them as floats to 3 places.
            return round(float(value), 3)

        with tempfile.TemporaryDirectory(suffix=name) as data_dir:
            # Data preparation output is irrelevant to the test; discard it.
            with contextlib.redirect_stdout(StringIO()):
                test_binaries.create_dummy_data(data_dir)
                test_binaries.preprocess_translation_data(data_dir)

            # train epochs 1 and 2 together
            train_log, valid_log = train_and_read_logs(data_dir)

            # train epoch 2, resuming from previous checkpoint 1
            first_checkpoint = os.path.join(data_dir, 'checkpoint1.pt')
            resume_checkpoint = os.path.join(data_dir, 'checkpoint_last.pt')
            os.rename(first_checkpoint, resume_checkpoint)
            train_res_log, valid_res_log = train_and_read_logs(data_dir)

            # Training statistics are checked first, then validation ones.
            comparisons = (
                (train_log, train_res_log,
                 ['loss', 'ppl', 'num_updates', 'gnorm']),
                (valid_log, valid_res_log,
                 ['valid_loss', 'valid_ppl', 'num_updates', 'best']),
            )
            for expected, resumed, keys in comparisons:
                for key in keys:
                    self.assertEqual(
                        rounded(expected[key]), rounded(resumed[key]))

    def test_reproducibility(self):
        """Resumed training matches an uninterrupted run with default flags."""
        self._test_reproducibility('test_reproducibility')

    def test_reproducibility_fp16(self):
        """Same comparison with ``--fp16`` and an explicit initial loss scale."""
        fp16_flags = [
            '--fp16',
            '--fp16-init-scale', '4096',
        ]
        self._test_reproducibility('test_reproducibility_fp16', fp16_flags)
# Run the test suite when this file is executed as the main module.
# NOTE(review): the module imports test_binaries relatively (`from . import
# test_binaries`), so it presumably has to be launched as part of its package
# (e.g. with `python -m`) rather than by file path — confirm.
if __name__ == '__main__':
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment