Commit 9f7b2b24 (unverified)
Authored Oct 17, 2020 by Stas Bekman; committed by GitHub on Oct 17, 2020
[s2s testing] turn all to unittests, use auto-delete temp dirs (#7859)
parent dc552b9b

Showing 6 changed files with 674 additions and 694 deletions (whitespace-only changes hidden in the hunks below).
examples/seq2seq/test_bash_script.py (+162, -162)
examples/seq2seq/test_data/wmt_en_ro/train.len (+0, -0)
examples/seq2seq/test_data/wmt_en_ro/val.len (+0, -0)
examples/seq2seq/test_datasets.py (+191, -195)
examples/seq2seq/test_finetune_trainer.py (+62, -64)
examples/seq2seq/test_seq2seq_examples.py (+259, -273)
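The thread running through all four Python files below: free-standing pytest functions become methods on a TestCasePlus subclass, and every tempfile.mkdtemp(...) becomes self.get_auto_remove_tmp_dir(), so each test's temp dir is deleted when the test ends instead of piling up. A minimal sketch of how such a helper can be built on unittest's addCleanup (an illustration, not the actual transformers.testing_utils implementation):

import shutil
import tempfile
import unittest


class TestCasePlusSketch(unittest.TestCase):
    """Illustrative stand-in for transformers.testing_utils.TestCasePlus."""

    def get_auto_remove_tmp_dir(self):
        tmp_dir = tempfile.mkdtemp()
        # cleanups run after the test finishes, whether it passed or failed
        self.addCleanup(shutil.rmtree, tmp_dir, ignore_errors=True)
        return tmp_dir


class ExampleTest(TestCasePlusSketch):
    def test_writes_then_auto_removes(self):
        output_dir = self.get_auto_remove_tmp_dir()
        with open(f"{output_dir}/scores.json", "w") as f:
            f.write("{}")  # the directory is removed when the test ends


if __name__ == "__main__":
    unittest.main()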
examples/seq2seq/test_bash_script.py

@@ -3,7 +3,6 @@
 import argparse
 import os
 import sys
-import tempfile
 from pathlib import Path
 from unittest.mock import patch
@@ -16,7 +15,7 @@ from distillation import BartSummarizationDistiller, distill_main
 from finetune import SummarizationModule, main
 from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
 from transformers import BartForConditionalGeneration, MarianMTModel
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from utils import load_json
@@ -24,18 +23,18 @@ MODEL_NAME = MBART_TINY
 MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"

-@slow
-@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
-def test_model_download():
+class TestAll(TestCasePlus):
+    @slow
+    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    def test_model_download(self):
         """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
         BartForConditionalGeneration.from_pretrained(MODEL_NAME)
         MarianMTModel.from_pretrained(MARIAN_MODEL)

-@timeout_decorator.timeout(120)
-@slow
-@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
-def test_train_mbart_cc25_enro_script():
+    @timeout_decorator.timeout(120)
+    @slow
+    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    def test_train_mbart_cc25_enro_script(self):
         data_dir = "examples/seq2seq/test_data/wmt_en_ro"
         env_vars_to_replace = {
             "--fp16_opt_level=O1": "",
@@ -53,7 +52,7 @@ def test_train_mbart_cc25_enro_script():
         bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
         for k, v in env_vars_to_replace.items():
             bash_script = bash_script.replace(k, str(v))
-    output_dir = tempfile.mkdtemp(prefix="output_mbart")
+        output_dir = self.get_auto_remove_tmp_dir()
         bash_script = bash_script.replace("--fp16 ", "")
         testargs = (
@@ -81,7 +80,9 @@ def test_train_mbart_cc25_enro_script():
         metrics = load_json(model.metrics_save_path)
         first_step_stats = metrics["val"][0]
         last_step_stats = metrics["val"][-1]
-    assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1  # +1 accounts for val_sanity_check
+        assert (
+            len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1
+        )  # +1 accounts for val_sanity_check
         assert last_step_stats["val_avg_gen_time"] >= 0.01
@@ -106,11 +107,10 @@ def test_train_mbart_cc25_enro_script():
         # assert len(metrics["val"]) == desired_n_evals
         assert len(metrics["test"]) == 1

-@timeout_decorator.timeout(600)
-@slow
-@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
-def test_opus_mt_distill_script():
+    @timeout_decorator.timeout(600)
+    @slow
+    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    def test_opus_mt_distill_script(self):
         data_dir = "examples/seq2seq/test_data/wmt_en_ro"
         env_vars_to_replace = {
             "--fp16_opt_level=O1": "",
@@ -131,7 +131,7 @@ def test_opus_mt_distill_script():
         for k, v in env_vars_to_replace.items():
             bash_script = bash_script.replace(k, str(v))
-    output_dir = tempfile.mkdtemp(prefix="marian_output")
+        output_dir = self.get_auto_remove_tmp_dir()
         bash_script = bash_script.replace("--fp16", "")
         epochs = 6
         testargs = (
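What test_train_mbart_cc25_enro_script is doing above, in brief: it reads a real training shell script, joins the backslash-continued lines, drops the "$@" passthrough, substitutes flag values, and hands the result to the script's Python entry point as argv. A condensed, self-contained sketch of that transformation (toy script text and values, not the real train_mbart_cc25_enro.sh contents):

# toy stand-ins for the script file and the env_vars_to_replace mapping
bash_script = 'python finetune.py \\\n  --fp16 \\\n  --output_dir $OUTPUT_DIR \\\n  "$@"'
env_vars_to_replace = {"$OUTPUT_DIR": "/tmp/out"}

# same chain the test uses: join continuations, drop the "$@" passthrough
bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
for k, v in env_vars_to_replace.items():
    bash_script = bash_script.replace(k, str(v))
bash_script = bash_script.replace("--fp16 ", "")  # CI is not assumed to have apex

print(bash_script.split())  # ['python', 'finetune.py', '--output_dir', '/tmp/out']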
examples/seq2seq/test_data/wmt_en_ro/train.len

No preview for this file type.

examples/seq2seq/test_data/wmt_en_ro/val.len

No preview for this file type.
examples/seq2seq/test_datasets.py

 import os
-import tempfile
 from pathlib import Path

 import numpy as np
@@ -7,11 +6,12 @@ import pytest
 from torch.utils.data import DataLoader

 from pack_dataset import pack_data_dir
+from parameterized import parameterized
 from save_len_file import save_len_file
 from test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
 from transformers import AutoTokenizer
 from transformers.modeling_bart import shift_tokens_right
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset
@@ -19,9 +19,8 @@ BERT_BASE_CASED = "bert-base-cased"
 PEGASUS_XSUM = "google/pegasus-xsum"

-@slow
-@pytest.mark.parametrize(
-    "tok_name",
+class TestAll(TestCasePlus):
+    @parameterized.expand(
         [
             MBART_TINY,
             MARIAN_TINY,
@@ -29,10 +28,11 @@ PEGASUS_XSUM = "google/pegasus-xsum"
             BART_TINY,
             PEGASUS_XSUM,
         ],
-)
-def test_seq2seq_dataset_truncation(tok_name):
+    )
+    @slow
+    def test_seq2seq_dataset_truncation(self, tok_name):
         tokenizer = AutoTokenizer.from_pretrained(tok_name)
-    tmp_dir = make_test_data_dir()
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
         max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
         max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
         max_src_len = 4
@@ -68,11 +68,10 @@ def test_seq2seq_dataset_truncation(tok_name):
             break  # No need to test every batch

-@pytest.mark.parametrize("tok", [BART_TINY, BERT_BASE_CASED])
-def test_legacy_dataset_truncation(tok):
+    @parameterized.expand([BART_TINY, BERT_BASE_CASED])
+    def test_legacy_dataset_truncation(self, tok):
         tokenizer = AutoTokenizer.from_pretrained(tok)
-    tmp_dir = make_test_data_dir()
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
         max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
         max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
         trunc_target = 4
@@ -94,13 +93,12 @@ def test_legacy_dataset_truncation(tok):
             assert max_len_target > trunc_target  # Truncated
             break  # No need to test every batch

-def test_pack_dataset():
+    def test_pack_dataset(self):
         tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

-    tmp_dir = Path(make_test_data_dir())
+        tmp_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
         orig_examples = tmp_dir.joinpath("train.source").open().readlines()
-    save_dir = Path(tempfile.mkdtemp(prefix="packed_"))
+        save_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
         pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
         orig_paths = {x.name for x in tmp_dir.iterdir()}
         new_paths = {x.name for x in save_dir.iterdir()}
@@ -112,12 +110,11 @@ def test_pack_dataset():
         assert len(packed_examples[0]) == sum(len(x) for x in orig_examples)
         assert orig_paths == new_paths

-@pytest.mark.skipif(not FAIRSEQ_AVAILABLE, reason="This test requires fairseq")
-def test_dynamic_batch_size():
+    @pytest.mark.skipif(not FAIRSEQ_AVAILABLE, reason="This test requires fairseq")
+    def test_dynamic_batch_size(self):
         if not FAIRSEQ_AVAILABLE:
             return
-    ds, max_tokens, tokenizer = _get_dataset(max_len=64)
+        ds, max_tokens, tokenizer = self._get_dataset(max_len=64)
         required_batch_size_multiple = 64
         batch_sampler = ds.make_dynamic_sampler(max_tokens, required_batch_size_multiple=required_batch_size_multiple)
         batch_sizes = [len(x) for x in batch_sampler]
@@ -138,9 +135,8 @@ def test_dynamic_batch_size():
         if failures:
             raise AssertionError(f"too many tokens in {len(failures)} batches")

-def test_sortish_sampler_reduces_padding():
-    ds, _, tokenizer = _get_dataset(max_len=512)
+    def test_sortish_sampler_reduces_padding(self):
+        ds, _, tokenizer = self._get_dataset(max_len=512)
         bs = 2
         sortish_sampler = ds.make_sortish_sampler(bs, shuffle=False)
@@ -156,8 +152,7 @@ def test_sortish_sampler_reduces_padding():
         assert sum(count_pad_tokens(sortish_dl)) < sum(count_pad_tokens(naive_dl))
         assert len(sortish_dl) == len(naive_dl)

-def _get_dataset(n_obs=1000, max_len=128):
+    def _get_dataset(self, n_obs=1000, max_len=128):
         if os.getenv("USE_REAL_DATA", False):
             data_dir = "examples/seq2seq/wmt_en_ro"
             max_tokens = max_len * 2 * 64
@@ -179,16 +174,13 @@ def _get_dataset(n_obs=1000, max_len=128):
         )
         return ds, max_tokens, tokenizer

-def test_distributed_sortish_sampler_splits_indices_between_procs():
-    ds, max_tokens, tokenizer = _get_dataset()
+    def test_distributed_sortish_sampler_splits_indices_between_procs(self):
+        ds, max_tokens, tokenizer = self._get_dataset()
         ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
         ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
         assert ids1.intersection(ids2) == set()

-@pytest.mark.parametrize(
-    "tok_name",
+    @parameterized.expand(
         [
             MBART_TINY,
             MARIAN_TINY,
@@ -196,13 +188,13 @@ def test_distributed_sortish_sampler_splits_indices_between_procs():
             BART_TINY,
             PEGASUS_XSUM,
         ],
-)
-def test_dataset_kwargs(tok_name):
+    )
+    def test_dataset_kwargs(self, tok_name):
         tokenizer = AutoTokenizer.from_pretrained(tok_name)
         if tok_name == MBART_TINY:
             train_dataset = Seq2SeqDataset(
                 tokenizer,
-            data_dir=make_test_data_dir(),
+                data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
                 type_path="train",
                 max_source_length=4,
                 max_target_length=8,
@@ -213,7 +205,11 @@ def test_dataset_kwargs(tok_name):
             assert "src_lang" in kwargs and "tgt_lang" in kwargs
         else:
             train_dataset = Seq2SeqDataset(
-            tokenizer, data_dir=make_test_data_dir(), type_path="train", max_source_length=4, max_target_length=8
+                tokenizer,
+                data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
+                type_path="train",
+                max_source_length=4,
+                max_target_length=8,
             )
             kwargs = train_dataset.dataset_kwargs
         assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs
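The decorator swap in this file is load-bearing, not cosmetic: pytest.mark.parametrize does not work on unittest.TestCase methods, while parameterized.expand generates one concrete test method per entry and runs under both pytest and plain unittest. A tiny self-contained example (plain strings standing in for the tiny-model constants):

import unittest

from parameterized import parameterized


class TestTokNames(unittest.TestCase):
    @parameterized.expand(["sshleifer/tiny-mbart", "sshleifer/tiny-marian-en-de"])
    def test_name_is_namespaced(self, tok_name):
        # expand() creates test_name_is_namespaced_0, _1, ... at class creation
        self.assertIn("/", tok_name)


if __name__ == "__main__":
    unittest.main()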
examples/seq2seq/test_finetune_trainer.py

 import os
 import sys
-import tempfile
 from unittest.mock import patch

-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import set_seed
@@ -15,18 +14,18 @@ set_seed(42)
 MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"

-def test_finetune_trainer():
-    output_dir = run_trainer(1, "12", MBART_TINY, 1)
+class TestFinetuneTrainer(TestCasePlus):
+    def test_finetune_trainer(self):
+        output_dir = self.run_trainer(1, "12", MBART_TINY, 1)
         logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
         eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
         first_step_stats = eval_metrics[0]
         assert "eval_bleu" in first_step_stats

-@slow
-def test_finetune_trainer_slow():
+    @slow
+    def test_finetune_trainer_slow(self):
         # There is a missing call to __init__process_group somewhere
-    output_dir = run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)
+        output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)
         # Check metrics
         logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
@@ -43,10 +42,9 @@ def test_finetune_trainer_slow():
         assert "test_generations.txt" in contents
         assert "test_results.json" in contents

-def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
+    def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
         data_dir = "examples/seq2seq/test_data/wmt_en_ro"
-    output_dir = tempfile.mkdtemp(prefix="test_output")
+        output_dir = self.get_auto_remove_tmp_dir()
         argv = f"""
             --model_name_or_path {model_name}
             --data_dir {data_dir}
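run_trainer above assembles a CLI string (argv) for the finetune_trainer script; the hunk ends before the invocation, but the pattern used throughout these example tests is to patch sys.argv and call the script's main() in-process. A minimal sketch of that pattern with a stand-in entry point:

import sys
from unittest.mock import patch


def cli_main():
    # stand-in for an entry point such as main() or run_generate();
    # the real scripts parse sys.argv with argparse
    return sys.argv[1:]


testargs = ["finetune_trainer.py", "--model_name_or_path", "sshleifer/tiny-mbart", "--num_train_epochs", "1"]
with patch.object(sys, "argv", testargs):
    assert cli_main() == testargs[1:]
# sys.argv is restored automatically when the with-block exits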
examples/seq2seq/test_seq2seq_examples.py

@@ -3,7 +3,6 @@ import logging
 import os
 import sys
 import tempfile
-import unittest
 from pathlib import Path
 from unittest.mock import patch
@@ -15,11 +14,12 @@ import lightning_base
 from convert_pl_checkpoint_to_hf import convert_pl_to_hf
 from distillation import distill_main
 from finetune import SummarizationModule, main
+from parameterized import parameterized
 from run_eval import generate_summaries_or_translations, run_generate
 from run_eval_search import run_search
 from transformers import AutoConfig, AutoModelForSeq2SeqLM
 from transformers.hf_api import HfApi
-from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
+from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_and_cuda, slow
 from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json
@@ -52,7 +52,7 @@ CHEAP_ARGS = {
     "student_decoder_layers": 1,
     "val_check_interval": 1.0,
     "output_dir": "",
-    "fp16": False,  # TODO: set this to CUDA_AVAILABLE if ci installs apex or start using native amp
+    "fp16": False,  # TODO (SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp
     "no_teacher": False,
     "fp16_opt_level": "O1",
     "gpus": 1 if CUDA_AVAILABLE else 0,
@@ -88,6 +88,7 @@ CHEAP_ARGS = {
     "student_encoder_layers": 1,
     "freeze_encoder": False,
     "auto_scale_batch_size": False,
+    "overwrite_output_dir": False,
 }
@@ -110,15 +111,14 @@ logger.addHandler(stream_handler)
 logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks

-def make_test_data_dir(**kwargs):
-    tmp_dir = Path(tempfile.mkdtemp(**kwargs))
+def make_test_data_dir(tmp_dir):
     for split in ["train", "val", "test"]:
-        _dump_articles((tmp_dir / f"{split}.source"), ARTICLES)
-        _dump_articles((tmp_dir / f"{split}.target"), SUMMARIES)
+        _dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES)
+        _dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES)
     return tmp_dir

-class TestSummarizationDistiller(unittest.TestCase):
+class TestSummarizationDistiller(TestCasePlus):
     @classmethod
     def setUpClass(cls):
         logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks
@@ -143,17 +143,6 @@ class TestSummarizationDistiller(unittest.TestCase):
             failures.append(m)
         assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"

-    @require_multigpu
-    @unittest.skip("Broken at the moment")
-    def test_multigpu(self):
-        updates = dict(
-            no_teacher=True,
-            freeze_encoder=True,
-            gpus=2,
-            sortish_sampler=True,
-        )
-        self._test_distiller_cli(updates, check_contents=False)
-
     def test_distill_no_teacher(self):
         updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
         self._test_distiller_cli(updates)
@@ -173,12 +162,12 @@ class TestSummarizationDistiller(unittest.TestCase):
         self.assertEqual(1, len(ckpts))
         transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
         self.assertEqual(len(transformer_ckpts), 2)
-        examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
-        out_path = tempfile.mktemp()
+        examples = lmap(str.strip, Path(model.hparams.data_dir).joinpath("test.source").open().readlines())
+        out_path = tempfile.mktemp()  # XXX: not being cleaned up
         generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr"))
         self.assertTrue(Path(out_path).exists())

-        out_path_new = tempfile.mkdtemp()
+        out_path_new = self.get_auto_remove_tmp_dir()
         convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
         assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))
@@ -253,8 +242,8 @@ class TestSummarizationDistiller(unittest.TestCase):
         )
         default_updates.update(updates)
         args_d: dict = CHEAP_ARGS.copy()
-        tmp_dir = make_test_data_dir()
-        output_dir = tempfile.mkdtemp(prefix="output_")
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
+        output_dir = self.get_auto_remove_tmp_dir()
         args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
         model = distill_main(argparse.Namespace(**args_d))
@@ -279,13 +268,15 @@ class TestSummarizationDistiller(unittest.TestCase):
         return model

-def run_eval_tester(model):
-    input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
+class TestTheRest(TestCasePlus):
+    def run_eval_tester(self, model):
+        input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
         output_file_name = input_file_name.parent / "utest_output.txt"
         assert not output_file_name.exists()
         articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
         _dump_articles(input_file_name, articles)
-    score_path = str(Path(tempfile.mkdtemp()) / "scores.json")
+        score_path = str(Path(self.get_auto_remove_tmp_dir()) / "scores.json")
         task = "translation_en_to_de" if model == T5_TINY else "summarization"
         testargs = f"""
             run_eval_search.py
@@ -301,27 +292,24 @@ def run_eval_tester(model):
         with patch.object(sys, "argv", testargs):
             run_generate()
             assert Path(output_file_name).exists()
-        os.remove(Path(output_file_name))
-
-
-# test one model to quickly (no-@slow) catch simple problems and do an
-# extensive testing of functionality with multiple models as @slow separately
-def test_run_eval():
-    run_eval_tester(T5_TINY)
+            # os.remove(Path(output_file_name))

+    # test one model to quickly (no-@slow) catch simple problems and do an
+    # extensive testing of functionality with multiple models as @slow separately
+    def test_run_eval(self):
+        self.run_eval_tester(T5_TINY)

-# any extra models should go into the list here - can be slow
-@slow
-@pytest.mark.parametrize("model", [BART_TINY, MBART_TINY])
-def test_run_eval_slow(model):
-    run_eval_tester(model)
+    # any extra models should go into the list here - can be slow
+    @parameterized.expand([BART_TINY, MBART_TINY])
+    @slow
+    def test_run_eval_slow(self, model):
+        self.run_eval_tester(model)

-# testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
-@slow
-@pytest.mark.parametrize("model", [T5_TINY, MBART_TINY])
-def test_run_eval_search(model):
-    input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
+    # testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
+    @parameterized.expand([T5_TINY, MBART_TINY])
+    @slow
+    def test_run_eval_search(self, model):
+        input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
         output_file_name = input_file_name.parent / "utest_output.txt"
         assert not output_file_name.exists()
@@ -334,7 +322,7 @@ def test_run_eval_search(model):
             ],
         }

-    tmp_dir = Path(tempfile.mkdtemp())
+        tmp_dir = Path(self.get_auto_remove_tmp_dir())
         score_path = str(tmp_dir / "scores.json")
         reference_path = str(tmp_dir / "val.target")
         _dump_articles(input_file_name, text["en"])
@@ -367,18 +355,16 @@ def test_run_eval_search(model):
         assert Path(output_file_name).exists()
         os.remove(Path(output_file_name))

-@pytest.mark.parametrize(
-    "model",
+    @parameterized.expand(
         [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
-)
-def test_finetune(model):
+    )
+    def test_finetune(self, model):
         args_d: dict = CHEAP_ARGS.copy()
         task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
         args_d["label_smoothing"] = 0.1 if task == "translation" else 0

-    tmp_dir = make_test_data_dir()
-    output_dir = tempfile.mkdtemp(prefix="output_")
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
+        output_dir = self.get_auto_remove_tmp_dir()
         args_d.update(
             data_dir=tmp_dir,
             model_name_or_path=model,
@@ -423,12 +409,11 @@ def test_finetune(model):
         assert isinstance(example_batch, dict)
         assert len(example_batch) >= 4

-def test_finetune_extra_model_args():
+    def test_finetune_extra_model_args(self):
         args_d: dict = CHEAP_ARGS.copy()
         task = "summarization"
-    tmp_dir = make_test_data_dir()
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
         args_d.update(
             data_dir=tmp_dir,
@@ -445,7 +430,7 @@ def test_finetune_extra_model_args():
         # test models whose config includes the extra_model_args
         model = BART_TINY
-    output_dir = tempfile.mkdtemp(prefix="output_1_")
+        output_dir = self.get_auto_remove_tmp_dir()
         args_d1 = args_d.copy()
         args_d1.update(
             model_name_or_path=model,
@@ -461,7 +446,7 @@ def test_finetune_extra_model_args():
         # test models whose config doesn't include the extra_model_args
         model = T5_TINY
-    output_dir = tempfile.mkdtemp(prefix="output_2_")
+        output_dir = self.get_auto_remove_tmp_dir()
         args_d2 = args_d.copy()
         args_d2.update(
             model_name_or_path=model,
@@ -474,15 +459,14 @@ def test_finetune_extra_model_args():
             model = main(args)
         assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"

-def test_finetune_lr_schedulers():
+    def test_finetune_lr_schedulers(self):
         args_d: dict = CHEAP_ARGS.copy()
         task = "summarization"
-    tmp_dir = make_test_data_dir()
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
         model = BART_TINY
-    output_dir = tempfile.mkdtemp(prefix="output_1_")
+        output_dir = self.get_auto_remove_tmp_dir()
         args_d.update(
             data_dir=tmp_dir,
@@ -531,4 +515,6 @@ def test_finetune_lr_schedulers():
         args_d1["lr_scheduler"] = supported_param
         args = argparse.Namespace(**args_d1)
         model = main(args)
-        assert getattr(model.hparams, "lr_scheduler") == supported_param, f"lr_scheduler={supported_param} shouldn't fail"
+        assert (
+            getattr(model.hparams, "lr_scheduler") == supported_param
+        ), f"lr_scheduler={supported_param} shouldn't fail"
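One loose end the diff flags itself: out_path = tempfile.mktemp()  # XXX: not being cleaned up. mktemp() only returns an unused name, creates nothing, and whatever the test writes there is never removed. A hypothetical follow-up consistent with this commit's pattern (the helper below is illustrative, not part of the commit):

from pathlib import Path


def make_out_path(test_case) -> Path:
    # test_case is assumed to provide get_auto_remove_tmp_dir(), as
    # TestCasePlus does; the file then lives inside a directory that is
    # already scheduled for deletion, so no manual cleanup is needed
    return Path(test_case.get_auto_remove_tmp_dir()) / "utest_output.txt"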