chenpangpang / transformers · Commits · 9f7b2b24

Unverified commit 9f7b2b24, authored Oct 17, 2020 by Stas Bekman, committed by GitHub on Oct 17, 2020

[s2s testing] turn all to unittests, use auto-delete temp dirs (#7859)

parent dc552b9b

Showing 6 changed files with 674 additions and 694 deletions (+674 -694)
examples/seq2seq/test_bash_script.py            +162 -162
examples/seq2seq/test_data/wmt_en_ro/train.len    +0   -0
examples/seq2seq/test_data/wmt_en_ro/val.len      +0   -0
examples/seq2seq/test_datasets.py               +191 -195
examples/seq2seq/test_finetune_trainer.py        +62  -64
examples/seq2seq/test_seq2seq_examples.py       +259 -273
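The change applied across all four test modules is the same: standalone pytest functions that created their own temporary directories with tempfile.mkdtemp (and never deleted them) become methods of unittest-style classes built on transformers.testing_utils.TestCasePlus, whose get_auto_remove_tmp_dir() hands out a temporary directory that is cleaned up automatically when the test finishes. A minimal before/after sketch of that pattern, with illustrative class and test names rather than ones taken from the diff:

# Sketch of the conversion this commit applies; names below are illustrative.
from transformers.testing_utils import TestCasePlus


# Before: a module-level pytest function owning a temp dir it never removes.
#
#     import tempfile
#
#     def test_something():
#         output_dir = tempfile.mkdtemp(prefix="output_")
#         ...
#
# After: a unittest-style test; the temp dir is deleted when the test ends.
class ExampleTest(TestCasePlus):
    def test_something(self):
        output_dir = self.get_auto_remove_tmp_dir()
        # ... run the example script with --output_dir={output_dir} ...
        self.assertTrue(output_dir)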
examples/seq2seq/test_bash_script.py

@@ -3,7 +3,6 @@
 import argparse
 import os
 import sys
-import tempfile
 from pathlib import Path
 from unittest.mock import patch

@@ -16,7 +15,7 @@ from distillation import BartSummarizationDistiller, distill_main
 from finetune import SummarizationModule, main
 from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
 from transformers import BartForConditionalGeneration, MarianMTModel
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from utils import load_json

@@ -24,18 +23,18 @@ MODEL_NAME = MBART_TINY
 MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"


-@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
-@slow
-def test_model_download():
-    """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
-    BartForConditionalGeneration.from_pretrained(MODEL_NAME)
-    MarianMTModel.from_pretrained(MARIAN_MODEL)
+@slow
+class TestAll(TestCasePlus):
+    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    def test_model_download(self):
+        """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
+        BartForConditionalGeneration.from_pretrained(MODEL_NAME)
+        MarianMTModel.from_pretrained(MARIAN_MODEL)

-@timeout_decorator.timeout(120)
-@slow
-@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
-def test_train_mbart_cc25_enro_script():
-    data_dir = "examples/seq2seq/test_data/wmt_en_ro"
-    env_vars_to_replace = {
-        "--fp16_opt_level=O1": "",
+    @timeout_decorator.timeout(120)
+    @slow
+    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    def test_train_mbart_cc25_enro_script(self):
+        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
+        env_vars_to_replace = {
+            "--fp16_opt_level=O1": "",

@@ -53,7 +52,7 @@ def test_train_mbart_cc25_enro_script():
         bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
         for k, v in env_vars_to_replace.items():
             bash_script = bash_script.replace(k, str(v))
-    output_dir = tempfile.mkdtemp(prefix="output_mbart")
+        output_dir = self.get_auto_remove_tmp_dir()
         bash_script = bash_script.replace("--fp16 ", "")
         testargs = (

@@ -81,7 +80,9 @@ def test_train_mbart_cc25_enro_script():
         metrics = load_json(model.metrics_save_path)
         first_step_stats = metrics["val"][0]
         last_step_stats = metrics["val"][-1]
-    assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1  # +1 accounts for val_sanity_check
+        assert (
+            len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1
+        )  # +1 accounts for val_sanity_check
         assert last_step_stats["val_avg_gen_time"] >= 0.01

@@ -106,11 +107,10 @@ def test_train_mbart_cc25_enro_script():
         # assert len(metrics["val"]) == desired_n_evals
         assert len(metrics["test"]) == 1

-@timeout_decorator.timeout(600)
-@slow
-@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
-def test_opus_mt_distill_script():
-    data_dir = "examples/seq2seq/test_data/wmt_en_ro"
-    env_vars_to_replace = {
-        "--fp16_opt_level=O1": "",
+    @timeout_decorator.timeout(600)
+    @slow
+    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    def test_opus_mt_distill_script(self):
+        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
+        env_vars_to_replace = {
+            "--fp16_opt_level=O1": "",

@@ -131,7 +131,7 @@ def test_opus_mt_distill_script():
         for k, v in env_vars_to_replace.items():
             bash_script = bash_script.replace(k, str(v))
-    output_dir = tempfile.mkdtemp(prefix="marian_output")
+        output_dir = self.get_auto_remove_tmp_dir()
         bash_script = bash_script.replace("--fp16", "")
         epochs = 6
         testargs = (
...
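With the conversion, @slow also appears at the class level, so every test in TestAll is skipped unless slow tests are enabled via the RUN_SLOW environment variable, while the per-method @pytest.mark.skipif(not CUDA_AVAILABLE, ...) guards stay in place. A minimal sketch of why a class-level skip gate covers all methods, using plain unittest as a stand-in for the transformers @slow decorator:

# Sketch: a skip-style decorator applied to the class skips every method in it.
import unittest


@unittest.skipUnless(False, "stand-in for the RUN_SLOW gate used by @slow")
class ExampleSlowSuite(unittest.TestCase):
    def test_a(self):
        self.fail("never runs unless the gate above is True")


if __name__ == "__main__":
    unittest.main()  # reports the test as skipped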
examples/seq2seq/test_data/wmt_en_ro/train.len
No preview for this file type

examples/seq2seq/test_data/wmt_en_ro/val.len
No preview for this file type
examples/seq2seq/test_datasets.py

 import os
-import tempfile
 from pathlib import Path

 import numpy as np

@@ -7,11 +6,12 @@ import pytest
 from torch.utils.data import DataLoader

 from pack_dataset import pack_data_dir
+from parameterized import parameterized
 from save_len_file import save_len_file
 from test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
 from transformers import AutoTokenizer
 from transformers.modeling_bart import shift_tokens_right
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset

@@ -19,9 +19,8 @@ BERT_BASE_CASED = "bert-base-cased"
 PEGASUS_XSUM = "google/pegasus-xsum"


-@slow
-@pytest.mark.parametrize(
-    "tok_name",
+class TestAll(TestCasePlus):
+    @parameterized.expand(
         [
             MBART_TINY,
             MARIAN_TINY,

@@ -29,10 +28,11 @@ PEGASUS_XSUM = "google/pegasus-xsum"
             BART_TINY,
             PEGASUS_XSUM,
         ],
     )
-def test_seq2seq_dataset_truncation(tok_name):
+    @slow
+    def test_seq2seq_dataset_truncation(self, tok_name):
         tokenizer = AutoTokenizer.from_pretrained(tok_name)
-    tmp_dir = make_test_data_dir()
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
         max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
         max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
         max_src_len = 4

@@ -68,11 +68,10 @@ def test_seq2seq_dataset_truncation(tok_name):
         break  # No need to test every batch

-@pytest.mark.parametrize("tok", [BART_TINY, BERT_BASE_CASED])
-def test_legacy_dataset_truncation(tok):
+    @parameterized.expand([BART_TINY, BERT_BASE_CASED])
+    def test_legacy_dataset_truncation(self, tok):
         tokenizer = AutoTokenizer.from_pretrained(tok)
-    tmp_dir = make_test_data_dir()
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
         max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
         max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
         trunc_target = 4

@@ -94,13 +93,12 @@ def test_legacy_dataset_truncation(tok):
         assert max_len_target > trunc_target  # Truncated
         break  # No need to test every batch

-def test_pack_dataset():
+    def test_pack_dataset(self):
         tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
-    tmp_dir = Path(make_test_data_dir())
+        tmp_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
         orig_examples = tmp_dir.joinpath("train.source").open().readlines()
-    save_dir = Path(tempfile.mkdtemp(prefix="packed_"))
+        save_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
         pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
         orig_paths = {x.name for x in tmp_dir.iterdir()}
         new_paths = {x.name for x in save_dir.iterdir()}

@@ -112,12 +110,11 @@ def test_pack_dataset():
         assert len(packed_examples[0]) == sum(len(x) for x in orig_examples)
         assert orig_paths == new_paths

-@pytest.mark.skipif(not FAIRSEQ_AVAILABLE, reason="This test requires fairseq")
-def test_dynamic_batch_size():
+    @pytest.mark.skipif(not FAIRSEQ_AVAILABLE, reason="This test requires fairseq")
+    def test_dynamic_batch_size(self):
         if not FAIRSEQ_AVAILABLE:
             return
-    ds, max_tokens, tokenizer = _get_dataset(max_len=64)
+        ds, max_tokens, tokenizer = self._get_dataset(max_len=64)
         required_batch_size_multiple = 64
         batch_sampler = ds.make_dynamic_sampler(max_tokens, required_batch_size_multiple=required_batch_size_multiple)
         batch_sizes = [len(x) for x in batch_sampler]

@@ -138,9 +135,8 @@ def test_dynamic_batch_size():
         if failures:
             raise AssertionError(f"too many tokens in {len(failures)} batches")

-def test_sortish_sampler_reduces_padding():
-    ds, _, tokenizer = _get_dataset(max_len=512)
+    def test_sortish_sampler_reduces_padding(self):
+        ds, _, tokenizer = self._get_dataset(max_len=512)
         bs = 2
         sortish_sampler = ds.make_sortish_sampler(bs, shuffle=False)

@@ -156,8 +152,7 @@ def test_sortish_sampler_reduces_padding():
         assert sum(count_pad_tokens(sortish_dl)) < sum(count_pad_tokens(naive_dl))
         assert len(sortish_dl) == len(naive_dl)

-def _get_dataset(n_obs=1000, max_len=128):
-    if os.getenv("USE_REAL_DATA", False):
-        data_dir = "examples/seq2seq/wmt_en_ro"
-        max_tokens = max_len * 2 * 64
+    def _get_dataset(self, n_obs=1000, max_len=128):
+        if os.getenv("USE_REAL_DATA", False):
+            data_dir = "examples/seq2seq/wmt_en_ro"
+            max_tokens = max_len * 2 * 64

@@ -179,16 +174,13 @@ def _get_dataset(n_obs=1000, max_len=128):
         )
         return ds, max_tokens, tokenizer

-def test_distributed_sortish_sampler_splits_indices_between_procs():
-    ds, max_tokens, tokenizer = _get_dataset()
+    def test_distributed_sortish_sampler_splits_indices_between_procs(self):
+        ds, max_tokens, tokenizer = self._get_dataset()
         ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
         ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
         assert ids1.intersection(ids2) == set()

-@pytest.mark.parametrize(
-    "tok_name",
+    @parameterized.expand(
         [
             MBART_TINY,
             MARIAN_TINY,

@@ -196,13 +188,13 @@ def test_distributed_sortish_sampler_splits_indices_between_procs():
             BART_TINY,
             PEGASUS_XSUM,
         ],
     )
-def test_dataset_kwargs(tok_name):
+    def test_dataset_kwargs(self, tok_name):
         tokenizer = AutoTokenizer.from_pretrained(tok_name)
         if tok_name == MBART_TINY:
             train_dataset = Seq2SeqDataset(
                 tokenizer,
-            data_dir=make_test_data_dir(),
+                data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
                 type_path="train",
                 max_source_length=4,
                 max_target_length=8,

@@ -213,7 +205,11 @@ def test_dataset_kwargs(tok_name):
             assert "src_lang" in kwargs and "tgt_lang" in kwargs
         else:
             train_dataset = Seq2SeqDataset(
-            tokenizer, data_dir=make_test_data_dir(), type_path="train", max_source_length=4, max_target_length=8
+                tokenizer,
+                data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
+                type_path="train",
+                max_source_length=4,
+                max_target_length=8,
             )
         kwargs = train_dataset.dataset_kwargs
         assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs
...
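One mechanical detail in this file: pytest.mark.parametrize is ignored on methods of unittest.TestCase subclasses, which is why the parametrized tests switch to parameterized.expand once they move into the TestCasePlus-based class. A small self-contained sketch of that substitution (the test and the parameter values below are illustrative):

# Sketch: parametrizing a unittest-style test with parameterized.expand,
# which generates one test method per entry in the list.
import unittest

from parameterized import parameterized


class ExampleTokenizerNames(unittest.TestCase):
    @parameterized.expand(["sshleifer/bart-tiny-random", "google/pegasus-xsum"])
    def test_name_looks_like_a_repo_id(self, tok_name):
        # purely illustrative assertion on the parameter itself
        self.assertIn("/", tok_name)


if __name__ == "__main__":
    unittest.main()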
examples/seq2seq/test_finetune_trainer.py

 import os
 import sys
-import tempfile
 from unittest.mock import patch

-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import set_seed

@@ -15,18 +14,18 @@ set_seed(42)
 MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"


-def test_finetune_trainer():
-    output_dir = run_trainer(1, "12", MBART_TINY, 1)
+class TestFinetuneTrainer(TestCasePlus):
+    def test_finetune_trainer(self):
+        output_dir = self.run_trainer(1, "12", MBART_TINY, 1)
         logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
         eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
         first_step_stats = eval_metrics[0]
         assert "eval_bleu" in first_step_stats

-@slow
-def test_finetune_trainer_slow():
+    @slow
+    def test_finetune_trainer_slow(self):
         # There is a missing call to __init__process_group somewhere
-    output_dir = run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)
+        output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)
         # Check metrics
         logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history

@@ -43,10 +42,9 @@ def test_finetune_trainer_slow():
         assert "test_generations.txt" in contents
         assert "test_results.json" in contents

-def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
+    def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
         data_dir = "examples/seq2seq/test_data/wmt_en_ro"
-    output_dir = tempfile.mkdtemp(prefix="test_output")
+        output_dir = self.get_auto_remove_tmp_dir()
         argv = f"""
             --model_name_or_path {model_name}
             --data_dir {data_dir}
...
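For context, the assertions in this file read back the metrics that the Trainer writes to trainer_state.json inside the (now auto-removed) output directory. A rough sketch of that check on its own, outside the test harness, with an illustrative path:

# Sketch: reading eval metrics back from a finished run's trainer_state.json.
import os

from transformers.trainer_callback import TrainerState

output_dir = "/tmp/some_finished_run"  # illustrative path to a completed run
state = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json"))
eval_metrics = [log for log in state.log_history if "eval_loss" in log]
print(eval_metrics[0] if eval_metrics else "no eval logs yet")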
examples/seq2seq/test_seq2seq_examples.py

@@ -3,7 +3,6 @@ import logging
 import os
 import sys
 import tempfile
-import unittest
 from pathlib import Path
 from unittest.mock import patch

@@ -15,11 +14,12 @@ import lightning_base
 from convert_pl_checkpoint_to_hf import convert_pl_to_hf
 from distillation import distill_main
 from finetune import SummarizationModule, main
+from parameterized import parameterized
 from run_eval import generate_summaries_or_translations, run_generate
 from run_eval_search import run_search
 from transformers import AutoConfig, AutoModelForSeq2SeqLM
 from transformers.hf_api import HfApi
-from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
+from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_and_cuda, slow
 from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json

@@ -52,7 +52,7 @@ CHEAP_ARGS = {
     "student_decoder_layers": 1,
     "val_check_interval": 1.0,
     "output_dir": "",
-    "fp16": False,  # TODO: set this to CUDA_AVAILABLE if ci installs apex or start using native amp
+    "fp16": False,  # TODO (SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp
     "no_teacher": False,
     "fp16_opt_level": "O1",
     "gpus": 1 if CUDA_AVAILABLE else 0,

@@ -88,6 +88,7 @@ CHEAP_ARGS = {
     "student_encoder_layers": 1,
     "freeze_encoder": False,
     "auto_scale_batch_size": False,
+    "overwrite_output_dir": False,
 }

@@ -110,15 +111,14 @@ logger.addHandler(stream_handler)
 logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks


-def make_test_data_dir(**kwargs):
-    tmp_dir = Path(tempfile.mkdtemp(**kwargs))
+def make_test_data_dir(tmp_dir):
     for split in ["train", "val", "test"]:
-        _dump_articles((tmp_dir / f"{split}.source"), ARTICLES)
-        _dump_articles((tmp_dir / f"{split}.target"), SUMMARIES)
+        _dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES)
+        _dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES)
     return tmp_dir


-class TestSummarizationDistiller(unittest.TestCase):
+class TestSummarizationDistiller(TestCasePlus):
     @classmethod
     def setUpClass(cls):
         logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks

@@ -143,17 +143,6 @@ class TestSummarizationDistiller(unittest.TestCase):
             failures.append(m)
         assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"

-    @require_multigpu
-    @unittest.skip("Broken at the moment")
-    def test_multigpu(self):
-        updates = dict(
-            no_teacher=True,
-            freeze_encoder=True,
-            gpus=2,
-            sortish_sampler=True,
-        )
-        self._test_distiller_cli(updates, check_contents=False)
-
     def test_distill_no_teacher(self):
         updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
         self._test_distiller_cli(updates)

@@ -173,12 +162,12 @@ class TestSummarizationDistiller(unittest.TestCase):
         self.assertEqual(1, len(ckpts))
         transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
         self.assertEqual(len(transformer_ckpts), 2)
-        examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
-        out_path = tempfile.mktemp()
+        examples = lmap(str.strip, Path(model.hparams.data_dir).joinpath("test.source").open().readlines())
+        out_path = tempfile.mktemp()  # XXX: not being cleaned up
         generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr"))
         self.assertTrue(Path(out_path).exists())
-        out_path_new = tempfile.mkdtemp()
+        out_path_new = self.get_auto_remove_tmp_dir()
         convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
         assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))

@@ -253,8 +242,8 @@ class TestSummarizationDistiller(unittest.TestCase):
         )
         default_updates.update(updates)
         args_d: dict = CHEAP_ARGS.copy()
-        tmp_dir = make_test_data_dir()
-        output_dir = tempfile.mkdtemp(prefix="output_")
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
+        output_dir = self.get_auto_remove_tmp_dir()
         args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
         model = distill_main(argparse.Namespace(**args_d))

@@ -279,13 +268,15 @@ class TestSummarizationDistiller(unittest.TestCase):
         return model


-def run_eval_tester(model):
-    input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
+class TestTheRest(TestCasePlus):
+    def run_eval_tester(self, model):
+        input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
         output_file_name = input_file_name.parent / "utest_output.txt"
         assert not output_file_name.exists()
         articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
         _dump_articles(input_file_name, articles)
-    score_path = str(Path(tempfile.mkdtemp()) / "scores.json")
+        score_path = str(Path(self.get_auto_remove_tmp_dir()) / "scores.json")
         task = "translation_en_to_de" if model == T5_TINY else "summarization"
         testargs = f"""
             run_eval_search.py

@@ -301,27 +292,24 @@ def run_eval_tester(model):
         with patch.object(sys, "argv", testargs):
             run_generate()
             assert Path(output_file_name).exists()
-        os.remove(Path(output_file_name))
+            # os.remove(Path(output_file_name))

-# test one model to quickly (no-@slow) catch simple problems and do an
-# extensive testing of functionality with multiple models as @slow separately
-def test_run_eval():
-    run_eval_tester(T5_TINY)
+    # test one model to quickly (no-@slow) catch simple problems and do an
+    # extensive testing of functionality with multiple models as @slow separately
+    def test_run_eval(self):
+        self.run_eval_tester(T5_TINY)

-# any extra models should go into the list here - can be slow
-@pytest.mark.parametrize("model", [BART_TINY, MBART_TINY])
-@slow
-def test_run_eval_slow(model):
-    run_eval_tester(model)
+    # any extra models should go into the list here - can be slow
+    @slow
+    @parameterized.expand([BART_TINY, MBART_TINY])
+    def test_run_eval_slow(self, model):
+        self.run_eval_tester(model)

-# testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
-@pytest.mark.parametrize("model", [T5_TINY, MBART_TINY])
-@slow
-def test_run_eval_search(model):
-    input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
+    # testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
+    @slow
+    @parameterized.expand([T5_TINY, MBART_TINY])
+    def test_run_eval_search(self, model):
+        input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
         output_file_name = input_file_name.parent / "utest_output.txt"
         assert not output_file_name.exists()

@@ -334,7 +322,7 @@ def test_run_eval_search(model):
         ],
     }
-    tmp_dir = Path(tempfile.mkdtemp())
+        tmp_dir = Path(self.get_auto_remove_tmp_dir())
         score_path = str(tmp_dir / "scores.json")
         reference_path = str(tmp_dir / "val.target")
         _dump_articles(input_file_name, text["en"])

@@ -367,18 +355,16 @@ def test_run_eval_search(model):
         assert Path(output_file_name).exists()
         os.remove(Path(output_file_name))

-@pytest.mark.parametrize(
-    "model",
-    [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
-)
-def test_finetune(model):
+    @parameterized.expand(
+        [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
+    )
+    def test_finetune(self, model):
         args_d: dict = CHEAP_ARGS.copy()
         task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
         args_d["label_smoothing"] = 0.1 if task == "translation" else 0
-    tmp_dir = make_test_data_dir()
-    output_dir = tempfile.mkdtemp(prefix="output_")
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
+        output_dir = self.get_auto_remove_tmp_dir()
         args_d.update(
             data_dir=tmp_dir,
             model_name_or_path=model,

@@ -423,12 +409,11 @@ def test_finetune(model):
         assert isinstance(example_batch, dict)
         assert len(example_batch) >= 4

-def test_finetune_extra_model_args():
+    def test_finetune_extra_model_args(self):
         args_d: dict = CHEAP_ARGS.copy()
         task = "summarization"
-    tmp_dir = make_test_data_dir()
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
         args_d.update(
             data_dir=tmp_dir,

@@ -445,7 +430,7 @@ def test_finetune_extra_model_args():
         # test models whose config includes the extra_model_args
         model = BART_TINY
-    output_dir = tempfile.mkdtemp(prefix="output_1_")
+        output_dir = self.get_auto_remove_tmp_dir()
         args_d1 = args_d.copy()
         args_d1.update(
             model_name_or_path=model,

@@ -461,7 +446,7 @@ def test_finetune_extra_model_args():
         # test models whose config doesn't include the extra_model_args
         model = T5_TINY
-    output_dir = tempfile.mkdtemp(prefix="output_2_")
+        output_dir = self.get_auto_remove_tmp_dir()
         args_d2 = args_d.copy()
         args_d2.update(
             model_name_or_path=model,

@@ -474,15 +459,14 @@ def test_finetune_extra_model_args():
             model = main(args)
         assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"

-def test_finetune_lr_schedulers():
+    def test_finetune_lr_schedulers(self):
         args_d: dict = CHEAP_ARGS.copy()
         task = "summarization"
-    tmp_dir = make_test_data_dir()
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
         model = BART_TINY
-    output_dir = tempfile.mkdtemp(prefix="output_1_")
+        output_dir = self.get_auto_remove_tmp_dir()
         args_d.update(
             data_dir=tmp_dir,

@@ -531,4 +515,6 @@ def test_finetune_lr_schedulers():
         args_d1["lr_scheduler"] = supported_param
         args = argparse.Namespace(**args_d1)
         model = main(args)
-    assert getattr(model.hparams, "lr_scheduler") == supported_param, f"lr_scheduler={supported_param} shouldn't fail"
+        assert (
+            getattr(model.hparams, "lr_scheduler") == supported_param
+        ), f"lr_scheduler={supported_param} shouldn't fail"
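Finally, note that make_test_data_dir no longer allocates its own directory with tempfile.mkdtemp(**kwargs); callers now pass one in (typically self.get_auto_remove_tmp_dir()) and the helper only writes the {split}.source / {split}.target files into it. A rough sketch of the helper's new shape, with a stand-in _dump_articles and illustrative ARTICLES / SUMMARIES in place of the module-level fixtures used in the real file:

# Sketch of the reworked helper: the caller owns (and auto-deletes) the directory.
import os


def _dump_articles(path, articles):
    # stand-in for the real helper: one example per line
    with open(path, "w") as f:
        f.write("\n".join(articles))


ARTICLES = ["example source line"]
SUMMARIES = ["example target line"]


def make_test_data_dir(tmp_dir):
    for split in ["train", "val", "test"]:
        _dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES)
        _dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES)
    return tmp_dir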