Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9f7b2b24
Unverified
Commit
9f7b2b24
authored
Oct 17, 2020
by
Stas Bekman
Committed by
GitHub
Oct 17, 2020
Browse files
[s2s testing] turn all to unittests, use auto-delete temp dirs (#7859)
parent
dc552b9b
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
674 additions
and
694 deletions
+674
-694
examples/seq2seq/test_bash_script.py
examples/seq2seq/test_bash_script.py
+162
-162
examples/seq2seq/test_data/wmt_en_ro/train.len
examples/seq2seq/test_data/wmt_en_ro/train.len
+0
-0
examples/seq2seq/test_data/wmt_en_ro/val.len
examples/seq2seq/test_data/wmt_en_ro/val.len
+0
-0
examples/seq2seq/test_datasets.py
examples/seq2seq/test_datasets.py
+191
-195
examples/seq2seq/test_finetune_trainer.py
examples/seq2seq/test_finetune_trainer.py
+62
-64
examples/seq2seq/test_seq2seq_examples.py
examples/seq2seq/test_seq2seq_examples.py
+259
-273
No files found.
examples/seq2seq/test_bash_script.py
View file @
9f7b2b24
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
import
argparse
import
argparse
import
os
import
os
import
sys
import
sys
import
tempfile
from
pathlib
import
Path
from
pathlib
import
Path
from
unittest.mock
import
patch
from
unittest.mock
import
patch
...
@@ -16,7 +15,7 @@ from distillation import BartSummarizationDistiller, distill_main
...
@@ -16,7 +15,7 @@ from distillation import BartSummarizationDistiller, distill_main
from
finetune
import
SummarizationModule
,
main
from
finetune
import
SummarizationModule
,
main
from
test_seq2seq_examples
import
CUDA_AVAILABLE
,
MBART_TINY
from
test_seq2seq_examples
import
CUDA_AVAILABLE
,
MBART_TINY
from
transformers
import
BartForConditionalGeneration
,
MarianMTModel
from
transformers
import
BartForConditionalGeneration
,
MarianMTModel
from
transformers.testing_utils
import
slow
from
transformers.testing_utils
import
TestCasePlus
,
slow
from
utils
import
load_json
from
utils
import
load_json
...
@@ -24,163 +23,164 @@ MODEL_NAME = MBART_TINY
...
@@ -24,163 +23,164 @@ MODEL_NAME = MBART_TINY
MARIAN_MODEL
=
"sshleifer/student_marian_en_ro_6_1"
MARIAN_MODEL
=
"sshleifer/student_marian_en_ro_6_1"
@
slow
class
TestAll
(
TestCasePlus
):
@
pytest
.
mark
.
skipif
(
not
CUDA_AVAILABLE
,
reason
=
"too slow to run on CPU"
)
@
slow
def
test_model_download
():
@
pytest
.
mark
.
skipif
(
not
CUDA_AVAILABLE
,
reason
=
"too slow to run on CPU"
)
"""This warms up the cache so that we can time the next test without including download time, which varies between machines."""
def
test_model_download
(
self
):
BartForConditionalGeneration
.
from_pretrained
(
MODEL_NAME
)
"""This warms up the cache so that we can time the next test without including download time, which varies between machines."""
MarianMTModel
.
from_pretrained
(
MARIAN_MODEL
)
BartForConditionalGeneration
.
from_pretrained
(
MODEL_NAME
)
MarianMTModel
.
from_pretrained
(
MARIAN_MODEL
)
@
timeout_decorator
.
timeout
(
120
)
@
timeout_decorator
.
timeout
(
120
)
@
slow
@
slow
@
pytest
.
mark
.
skipif
(
not
CUDA_AVAILABLE
,
reason
=
"too slow to run on CPU"
)
@
pytest
.
mark
.
skipif
(
not
CUDA_AVAILABLE
,
reason
=
"too slow to run on CPU"
)
def
test_train_mbart_cc25_enro_script
():
def
test_train_mbart_cc25_enro_script
(
self
):
data_dir
=
"examples/seq2seq/test_data/wmt_en_ro"
data_dir
=
"examples/seq2seq/test_data/wmt_en_ro"
env_vars_to_replace
=
{
env_vars_to_replace
=
{
"--fp16_opt_level=O1"
:
""
,
"--fp16_opt_level=O1"
:
""
,
"$MAX_LEN"
:
128
,
"$MAX_LEN"
:
128
,
"$BS"
:
4
,
"$BS"
:
4
,
"$GAS"
:
1
,
"$GAS"
:
1
,
"$ENRO_DIR"
:
data_dir
,
"$ENRO_DIR"
:
data_dir
,
"facebook/mbart-large-cc25"
:
MODEL_NAME
,
"facebook/mbart-large-cc25"
:
MODEL_NAME
,
# Download is 120MB in previous test.
# Download is 120MB in previous test.
"val_check_interval=0.25"
:
"val_check_interval=1.0"
,
"val_check_interval=0.25"
:
"val_check_interval=1.0"
,
}
}
# Clean up bash script
# Clean up bash script
bash_script
=
Path
(
"examples/seq2seq/train_mbart_cc25_enro.sh"
).
open
().
read
().
split
(
"finetune.py"
)[
1
].
strip
()
bash_script
=
Path
(
"examples/seq2seq/train_mbart_cc25_enro.sh"
).
open
().
read
().
split
(
"finetune.py"
)[
1
].
strip
()
bash_script
=
bash_script
.
replace
(
"
\\\n
"
,
""
).
strip
().
replace
(
'"$@"'
,
""
)
bash_script
=
bash_script
.
replace
(
"
\\\n
"
,
""
).
strip
().
replace
(
'"$@"'
,
""
)
for
k
,
v
in
env_vars_to_replace
.
items
():
for
k
,
v
in
env_vars_to_replace
.
items
():
bash_script
=
bash_script
.
replace
(
k
,
str
(
v
))
bash_script
=
bash_script
.
replace
(
k
,
str
(
v
))
output_dir
=
tempfile
.
mkdtemp
(
prefix
=
"output_mbart"
)
output_dir
=
self
.
get_auto_remove_tmp_dir
()
bash_script
=
bash_script
.
replace
(
"--fp16 "
,
""
)
bash_script
=
bash_script
.
replace
(
"--fp16 "
,
""
)
testargs
=
(
testargs
=
(
[
"finetune.py"
]
[
"finetune.py"
]
+
bash_script
.
split
()
+
bash_script
.
split
()
+
[
+
[
f
"--output_dir=
{
output_dir
}
"
,
f
"--output_dir=
{
output_dir
}
"
,
"--gpus=1"
,
"--gpus=1"
,
"--learning_rate=3e-1"
,
"--learning_rate=3e-1"
,
"--warmup_steps=0"
,
"--warmup_steps=0"
,
"--val_check_interval=1.0"
,
"--val_check_interval=1.0"
,
"--tokenizer_name=facebook/mbart-large-en-ro"
,
"--tokenizer_name=facebook/mbart-large-en-ro"
,
]
]
)
)
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
parser
=
SummarizationModule
.
add_model_specific_args
(
parser
,
os
.
getcwd
())
parser
=
SummarizationModule
.
add_model_specific_args
(
parser
,
os
.
getcwd
())
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
args
.
do_predict
=
False
args
.
do_predict
=
False
# assert args.gpus == gpus THIS BREAKS for multigpu
# assert args.gpus == gpus THIS BREAKS for multigpu
model
=
main
(
args
)
model
=
main
(
args
)
# Check metrics
# Check metrics
metrics
=
load_json
(
model
.
metrics_save_path
)
metrics
=
load_json
(
model
.
metrics_save_path
)
first_step_stats
=
metrics
[
"val"
][
0
]
first_step_stats
=
metrics
[
"val"
][
0
]
last_step_stats
=
metrics
[
"val"
][
-
1
]
last_step_stats
=
metrics
[
"val"
][
-
1
]
assert
len
(
metrics
[
"val"
])
==
(
args
.
max_epochs
/
args
.
val_check_interval
)
+
1
# +1 accounts for val_sanity_check
assert
(
len
(
metrics
[
"val"
])
==
(
args
.
max_epochs
/
args
.
val_check_interval
)
+
1
assert
last_step_stats
[
"val_avg_gen_time"
]
>=
0.01
)
# +1 accounts for val_sanity_check
assert
first_step_stats
[
"val_avg_bleu"
]
<
last_step_stats
[
"val_avg_bleu"
]
# model learned nothing
assert
last_step_stats
[
"val_avg_gen_time"
]
>=
0.01
assert
1.0
>=
last_step_stats
[
"val_avg_gen_time"
]
# model hanging on generate. Maybe bad config was saved.
assert
isinstance
(
last_step_stats
[
f
"val_avg_
{
model
.
val_metric
}
"
],
float
)
assert
first_step_stats
[
"val_avg_bleu"
]
<
last_step_stats
[
"val_avg_bleu"
]
# model learned nothing
assert
1.0
>=
last_step_stats
[
"val_avg_gen_time"
]
# model hanging on generate. Maybe bad config was saved.
# check lightning ckpt can be loaded and has a reasonable statedict
assert
isinstance
(
last_step_stats
[
f
"val_avg_
{
model
.
val_metric
}
"
],
float
)
contents
=
os
.
listdir
(
output_dir
)
ckpt_path
=
[
x
for
x
in
contents
if
x
.
endswith
(
".ckpt"
)][
0
]
# check lightning ckpt can be loaded and has a reasonable statedict
full_path
=
os
.
path
.
join
(
args
.
output_dir
,
ckpt_path
)
contents
=
os
.
listdir
(
output_dir
)
ckpt
=
torch
.
load
(
full_path
,
map_location
=
"cpu"
)
ckpt_path
=
[
x
for
x
in
contents
if
x
.
endswith
(
".ckpt"
)][
0
]
expected_key
=
"model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
full_path
=
os
.
path
.
join
(
args
.
output_dir
,
ckpt_path
)
assert
expected_key
in
ckpt
[
"state_dict"
]
ckpt
=
torch
.
load
(
full_path
,
map_location
=
"cpu"
)
assert
ckpt
[
"state_dict"
][
"model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
].
dtype
==
torch
.
float32
expected_key
=
"model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
assert
expected_key
in
ckpt
[
"state_dict"
]
# TODO: turn on args.do_predict when PL bug fixed.
assert
ckpt
[
"state_dict"
][
"model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
].
dtype
==
torch
.
float32
if
args
.
do_predict
:
contents
=
{
os
.
path
.
basename
(
p
)
for
p
in
contents
}
# TODO: turn on args.do_predict when PL bug fixed.
assert
"test_generations.txt"
in
contents
if
args
.
do_predict
:
assert
"test_results.txt"
in
contents
contents
=
{
os
.
path
.
basename
(
p
)
for
p
in
contents
}
# assert len(metrics["val"]) == desired_n_evals
assert
"test_generations.txt"
in
contents
assert
len
(
metrics
[
"test"
])
==
1
assert
"test_results.txt"
in
contents
# assert len(metrics["val"]) == desired_n_evals
assert
len
(
metrics
[
"test"
])
==
1
@
timeout_decorator
.
timeout
(
600
)
@
slow
@
timeout_decorator
.
timeout
(
600
)
@
pytest
.
mark
.
skipif
(
not
CUDA_AVAILABLE
,
reason
=
"too slow to run on CPU"
)
@
slow
def
test_opus_mt_distill_script
():
@
pytest
.
mark
.
skipif
(
not
CUDA_AVAILABLE
,
reason
=
"too slow to run on CPU"
)
data_dir
=
"examples/seq2seq/test_data/wmt_en_ro"
def
test_opus_mt_distill_script
(
self
):
env_vars_to_replace
=
{
data_dir
=
"examples/seq2seq/test_data/wmt_en_ro"
"--fp16_opt_level=O1"
:
""
,
env_vars_to_replace
=
{
"$MAX_LEN"
:
128
,
"--fp16_opt_level=O1"
:
""
,
"$BS"
:
16
,
"$MAX_LEN"
:
128
,
"$GAS"
:
1
,
"$BS"
:
16
,
"$ENRO_DIR"
:
data_dir
,
"$GAS"
:
1
,
"$m"
:
"sshleifer/student_marian_en_ro_6_1"
,
"$ENRO_DIR"
:
data_dir
,
"val_check_interval=0.25"
:
"val_check_interval=1.0"
,
"$m"
:
"sshleifer/student_marian_en_ro_6_1"
,
}
"val_check_interval=0.25"
:
"val_check_interval=1.0"
,
}
# Clean up bash script
bash_script
=
(
# Clean up bash script
Path
(
"examples/seq2seq/distil_marian_no_teacher.sh"
).
open
().
read
().
split
(
"distillation.py"
)[
1
].
strip
()
bash_script
=
(
)
Path
(
"examples/seq2seq/distil_marian_no_teacher.sh"
).
open
().
read
().
split
(
"distillation.py"
)[
1
].
strip
()
bash_script
=
bash_script
.
replace
(
"
\\\n
"
,
""
).
strip
().
replace
(
'"$@"'
,
""
)
)
bash_script
=
bash_script
.
replace
(
"--fp16 "
,
" "
)
bash_script
=
bash_script
.
replace
(
"
\\\n
"
,
""
).
strip
().
replace
(
'"$@"'
,
""
)
bash_script
=
bash_script
.
replace
(
"--fp16 "
,
" "
)
for
k
,
v
in
env_vars_to_replace
.
items
():
bash_script
=
bash_script
.
replace
(
k
,
str
(
v
))
for
k
,
v
in
env_vars_to_replace
.
items
():
output_dir
=
tempfile
.
mkdtemp
(
prefix
=
"marian_output"
)
bash_script
=
bash_script
.
replace
(
k
,
str
(
v
))
bash_script
=
bash_script
.
replace
(
"--fp16"
,
""
)
output_dir
=
self
.
get_auto_remove_tmp_dir
()
epochs
=
6
bash_script
=
bash_script
.
replace
(
"--fp16"
,
""
)
testargs
=
(
epochs
=
6
[
"distillation.py"
]
testargs
=
(
+
bash_script
.
split
()
[
"distillation.py"
]
+
[
+
bash_script
.
split
()
f
"--output_dir=
{
output_dir
}
"
,
+
[
"--gpus=1"
,
f
"--output_dir=
{
output_dir
}
"
,
"--learning_rate=1e-3"
,
"--gpus=1"
,
f
"--num_train_epochs=
{
epochs
}
"
,
"--learning_rate=1e-3"
,
"--warmup_steps=10"
,
f
"--num_train_epochs=
{
epochs
}
"
,
"--val_check_interval=1.0"
,
"--warmup_steps=10"
,
]
"--val_check_interval=1.0"
,
)
]
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
)
parser
=
argparse
.
ArgumentParser
()
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
parser
=
argparse
.
ArgumentParser
()
parser
=
BartSummarizationDistiller
.
add_model_specific_args
(
parser
,
os
.
getcwd
())
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
args
=
parser
.
parse_args
()
parser
=
BartSummarizationDistiller
.
add_model_specific_args
(
parser
,
os
.
getcwd
())
args
.
do_predict
=
False
args
=
parser
.
parse_args
()
# assert args.gpus == gpus THIS BREAKS for multigpu
args
.
do_predict
=
False
# assert args.gpus == gpus THIS BREAKS for multigpu
model
=
distill_main
(
args
)
model
=
distill_main
(
args
)
# Check metrics
metrics
=
load_json
(
model
.
metrics_save_path
)
# Check metrics
first_step_stats
=
metrics
[
"val"
][
0
]
metrics
=
load_json
(
model
.
metrics_save_path
)
last_step_stats
=
metrics
[
"val"
][
-
1
]
first_step_stats
=
metrics
[
"val"
][
0
]
assert
len
(
metrics
[
"val"
])
>=
(
args
.
max_epochs
/
args
.
val_check_interval
)
# +1 accounts for val_sanity_check
last_step_stats
=
metrics
[
"val"
][
-
1
]
assert
len
(
metrics
[
"val"
])
>=
(
args
.
max_epochs
/
args
.
val_check_interval
)
# +1 accounts for val_sanity_check
assert
last_step_stats
[
"val_avg_gen_time"
]
>=
0.01
assert
last_step_stats
[
"val_avg_gen_time"
]
>=
0.01
assert
first_step_stats
[
"val_avg_bleu"
]
<
last_step_stats
[
"val_avg_bleu"
]
# model learned nothing
assert
1.0
>=
last_step_stats
[
"val_avg_gen_time"
]
# model hanging on generate. Maybe bad config was saved.
assert
first_step_stats
[
"val_avg_bleu"
]
<
last_step_stats
[
"val_avg_bleu"
]
# model learned nothing
assert
isinstance
(
last_step_stats
[
f
"val_avg_
{
model
.
val_metric
}
"
],
float
)
assert
1.0
>=
last_step_stats
[
"val_avg_gen_time"
]
# model hanging on generate. Maybe bad config was saved.
assert
isinstance
(
last_step_stats
[
f
"val_avg_
{
model
.
val_metric
}
"
],
float
)
# check lightning ckpt can be loaded and has a reasonable statedict
contents
=
os
.
listdir
(
output_dir
)
# check lightning ckpt can be loaded and has a reasonable statedict
ckpt_path
=
[
x
for
x
in
contents
if
x
.
endswith
(
".ckpt"
)][
0
]
contents
=
os
.
listdir
(
output_dir
)
full_path
=
os
.
path
.
join
(
args
.
output_dir
,
ckpt_path
)
ckpt_path
=
[
x
for
x
in
contents
if
x
.
endswith
(
".ckpt"
)][
0
]
ckpt
=
torch
.
load
(
full_path
,
map_location
=
"cpu"
)
full_path
=
os
.
path
.
join
(
args
.
output_dir
,
ckpt_path
)
expected_key
=
"model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
ckpt
=
torch
.
load
(
full_path
,
map_location
=
"cpu"
)
assert
expected_key
in
ckpt
[
"state_dict"
]
expected_key
=
"model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
assert
ckpt
[
"state_dict"
][
"model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
].
dtype
==
torch
.
float32
assert
expected_key
in
ckpt
[
"state_dict"
]
assert
ckpt
[
"state_dict"
][
"model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
].
dtype
==
torch
.
float32
# TODO: turn on args.do_predict when PL bug fixed.
if
args
.
do_predict
:
# TODO: turn on args.do_predict when PL bug fixed.
contents
=
{
os
.
path
.
basename
(
p
)
for
p
in
contents
}
if
args
.
do_predict
:
assert
"test_generations.txt"
in
contents
contents
=
{
os
.
path
.
basename
(
p
)
for
p
in
contents
}
assert
"test_results.txt"
in
contents
assert
"test_generations.txt"
in
contents
# assert len(metrics["val"]) == desired_n_evals
assert
"test_results.txt"
in
contents
assert
len
(
metrics
[
"test"
])
==
1
# assert len(metrics["val"]) == desired_n_evals
assert
len
(
metrics
[
"test"
])
==
1
examples/seq2seq/test_data/wmt_en_ro/train.len
View file @
9f7b2b24
No preview for this file type
examples/seq2seq/test_data/wmt_en_ro/val.len
View file @
9f7b2b24
No preview for this file type
examples/seq2seq/test_datasets.py
View file @
9f7b2b24
import
os
import
os
import
tempfile
from
pathlib
import
Path
from
pathlib
import
Path
import
numpy
as
np
import
numpy
as
np
...
@@ -7,11 +6,12 @@ import pytest
...
@@ -7,11 +6,12 @@ import pytest
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
pack_dataset
import
pack_data_dir
from
pack_dataset
import
pack_data_dir
from
parameterized
import
parameterized
from
save_len_file
import
save_len_file
from
save_len_file
import
save_len_file
from
test_seq2seq_examples
import
ARTICLES
,
BART_TINY
,
MARIAN_TINY
,
MBART_TINY
,
SUMMARIES
,
T5_TINY
,
make_test_data_dir
from
test_seq2seq_examples
import
ARTICLES
,
BART_TINY
,
MARIAN_TINY
,
MBART_TINY
,
SUMMARIES
,
T5_TINY
,
make_test_data_dir
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
from
transformers.modeling_bart
import
shift_tokens_right
from
transformers.modeling_bart
import
shift_tokens_right
from
transformers.testing_utils
import
slow
from
transformers.testing_utils
import
TestCasePlus
,
slow
from
utils
import
FAIRSEQ_AVAILABLE
,
DistributedSortishSampler
,
LegacySeq2SeqDataset
,
Seq2SeqDataset
from
utils
import
FAIRSEQ_AVAILABLE
,
DistributedSortishSampler
,
LegacySeq2SeqDataset
,
Seq2SeqDataset
...
@@ -19,202 +19,198 @@ BERT_BASE_CASED = "bert-base-cased"
...
@@ -19,202 +19,198 @@ BERT_BASE_CASED = "bert-base-cased"
PEGASUS_XSUM
=
"google/pegasus-xsum"
PEGASUS_XSUM
=
"google/pegasus-xsum"
@
slow
class
TestAll
(
TestCasePlus
):
@
pytest
.
mark
.
parametrize
(
@
parameterized
.
expand
(
"tok_name"
,
[
[
MBART_TINY
,
MBART_TINY
,
MARIAN_TINY
,
MARIAN_TINY
,
T5_TINY
,
T5_TINY
,
BART_TINY
,
BART_TINY
,
PEGASUS_XSUM
,
PEGASUS_XSUM
,
],
],
)
def
test_seq2seq_dataset_truncation
(
tok_name
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
tok_name
)
tmp_dir
=
make_test_data_dir
()
max_len_source
=
max
(
len
(
tokenizer
.
encode
(
a
))
for
a
in
ARTICLES
)
max_len_target
=
max
(
len
(
tokenizer
.
encode
(
a
))
for
a
in
SUMMARIES
)
max_src_len
=
4
max_tgt_len
=
8
assert
max_len_target
>
max_src_len
# Will be truncated
assert
max_len_source
>
max_src_len
# Will be truncated
src_lang
,
tgt_lang
=
"ro_RO"
,
"de_DE"
# ignored for all but mbart, but never causes error.
train_dataset
=
Seq2SeqDataset
(
tokenizer
,
data_dir
=
tmp_dir
,
type_path
=
"train"
,
max_source_length
=
max_src_len
,
max_target_length
=
max_tgt_len
,
# ignored
src_lang
=
src_lang
,
tgt_lang
=
tgt_lang
,
)
)
dataloader
=
DataLoader
(
train_dataset
,
batch_size
=
2
,
collate_fn
=
train_dataset
.
collate_fn
)
@
slow
for
batch
in
dataloader
:
def
test_seq2seq_dataset_truncation
(
self
,
tok_name
):
assert
isinstance
(
batch
,
dict
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
tok_name
)
assert
batch
[
"attention_mask"
].
shape
==
batch
[
"input_ids"
].
shape
tmp_dir
=
make_test_data_dir
(
tmp_dir
=
self
.
get_auto_remove_tmp_dir
())
# show that articles were trimmed.
max_len_source
=
max
(
len
(
tokenizer
.
encode
(
a
))
for
a
in
ARTICLES
)
assert
batch
[
"input_ids"
].
shape
[
1
]
==
max_src_len
max_len_target
=
max
(
len
(
tokenizer
.
encode
(
a
))
for
a
in
SUMMARIES
)
# show that targets are the same len
max_src_len
=
4
assert
batch
[
"labels"
].
shape
[
1
]
==
max_tgt_len
max_tgt_len
=
8
if
tok_name
!=
MBART_TINY
:
assert
max_len_target
>
max_src_len
# Will be truncated
continue
assert
max_len_source
>
max_src_len
# Will be truncated
# check language codes in correct place
src_lang
,
tgt_lang
=
"ro_RO"
,
"de_DE"
# ignored for all but mbart, but never causes error.
batch
[
"decoder_input_ids"
]
=
shift_tokens_right
(
batch
[
"labels"
],
tokenizer
.
pad_token_id
)
assert
batch
[
"decoder_input_ids"
][
0
,
0
].
item
()
==
tokenizer
.
lang_code_to_id
[
tgt_lang
]
assert
batch
[
"decoder_input_ids"
][
0
,
-
1
].
item
()
==
tokenizer
.
eos_token_id
assert
batch
[
"input_ids"
][
0
,
-
2
].
item
()
==
tokenizer
.
eos_token_id
assert
batch
[
"input_ids"
][
0
,
-
1
].
item
()
==
tokenizer
.
lang_code_to_id
[
src_lang
]
break
# No need to test every batch
@
pytest
.
mark
.
parametrize
(
"tok"
,
[
BART_TINY
,
BERT_BASE_CASED
])
def
test_legacy_dataset_truncation
(
tok
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
tok
)
tmp_dir
=
make_test_data_dir
()
max_len_source
=
max
(
len
(
tokenizer
.
encode
(
a
))
for
a
in
ARTICLES
)
max_len_target
=
max
(
len
(
tokenizer
.
encode
(
a
))
for
a
in
SUMMARIES
)
trunc_target
=
4
train_dataset
=
LegacySeq2SeqDataset
(
tokenizer
,
data_dir
=
tmp_dir
,
type_path
=
"train"
,
max_source_length
=
20
,
max_target_length
=
trunc_target
,
)
dataloader
=
DataLoader
(
train_dataset
,
batch_size
=
2
,
collate_fn
=
train_dataset
.
collate_fn
)
for
batch
in
dataloader
:
assert
batch
[
"attention_mask"
].
shape
==
batch
[
"input_ids"
].
shape
# show that articles were trimmed.
assert
batch
[
"input_ids"
].
shape
[
1
]
==
max_len_source
assert
20
>=
batch
[
"input_ids"
].
shape
[
1
]
# trimmed significantly
# show that targets were truncated
assert
batch
[
"labels"
].
shape
[
1
]
==
trunc_target
# Truncated
assert
max_len_target
>
trunc_target
# Truncated
break
# No need to test every batch
def
test_pack_dataset
():
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"facebook/mbart-large-cc25"
)
tmp_dir
=
Path
(
make_test_data_dir
())
orig_examples
=
tmp_dir
.
joinpath
(
"train.source"
).
open
().
readlines
()
save_dir
=
Path
(
tempfile
.
mkdtemp
(
prefix
=
"packed_"
))
pack_data_dir
(
tokenizer
,
tmp_dir
,
128
,
save_dir
)
orig_paths
=
{
x
.
name
for
x
in
tmp_dir
.
iterdir
()}
new_paths
=
{
x
.
name
for
x
in
save_dir
.
iterdir
()}
packed_examples
=
save_dir
.
joinpath
(
"train.source"
).
open
().
readlines
()
# orig: [' Sam ate lunch today.\n', 'Sams lunch ingredients.']
# desired_packed: [' Sam ate lunch today.\n Sams lunch ingredients.']
assert
len
(
packed_examples
)
<
len
(
orig_examples
)
assert
len
(
packed_examples
)
==
1
assert
len
(
packed_examples
[
0
])
==
sum
(
len
(
x
)
for
x
in
orig_examples
)
assert
orig_paths
==
new_paths
@
pytest
.
mark
.
skipif
(
not
FAIRSEQ_AVAILABLE
,
reason
=
"This test requires fairseq"
)
def
test_dynamic_batch_size
():
if
not
FAIRSEQ_AVAILABLE
:
return
ds
,
max_tokens
,
tokenizer
=
_get_dataset
(
max_len
=
64
)
required_batch_size_multiple
=
64
batch_sampler
=
ds
.
make_dynamic_sampler
(
max_tokens
,
required_batch_size_multiple
=
required_batch_size_multiple
)
batch_sizes
=
[
len
(
x
)
for
x
in
batch_sampler
]
assert
len
(
set
(
batch_sizes
))
>
1
# it's not dynamic batch size if every batch is the same length
assert
sum
(
batch_sizes
)
==
len
(
ds
)
# no dropped or added examples
data_loader
=
DataLoader
(
ds
,
batch_sampler
=
batch_sampler
,
collate_fn
=
ds
.
collate_fn
,
num_workers
=
2
)
failures
=
[]
num_src_per_batch
=
[]
for
batch
in
data_loader
:
src_shape
=
batch
[
"input_ids"
].
shape
bs
=
src_shape
[
0
]
assert
bs
%
required_batch_size_multiple
==
0
or
bs
<
required_batch_size_multiple
num_src_tokens
=
np
.
product
(
batch
[
"input_ids"
].
shape
)
num_src_per_batch
.
append
(
num_src_tokens
)
if
num_src_tokens
>
(
max_tokens
*
1.1
):
failures
.
append
(
num_src_tokens
)
assert
num_src_per_batch
[
0
]
==
max
(
num_src_per_batch
)
if
failures
:
raise
AssertionError
(
f
"too many tokens in
{
len
(
failures
)
}
batches"
)
def
test_sortish_sampler_reduces_padding
():
ds
,
_
,
tokenizer
=
_get_dataset
(
max_len
=
512
)
bs
=
2
sortish_sampler
=
ds
.
make_sortish_sampler
(
bs
,
shuffle
=
False
)
naive_dl
=
DataLoader
(
ds
,
batch_size
=
bs
,
collate_fn
=
ds
.
collate_fn
,
num_workers
=
2
)
sortish_dl
=
DataLoader
(
ds
,
batch_size
=
bs
,
collate_fn
=
ds
.
collate_fn
,
num_workers
=
2
,
sampler
=
sortish_sampler
)
pad
=
tokenizer
.
pad_token_id
def
count_pad_tokens
(
data_loader
,
k
=
"input_ids"
):
return
[
batch
[
k
].
eq
(
pad
).
sum
().
item
()
for
batch
in
data_loader
]
assert
sum
(
count_pad_tokens
(
sortish_dl
,
k
=
"labels"
))
<
sum
(
count_pad_tokens
(
naive_dl
,
k
=
"labels"
))
assert
sum
(
count_pad_tokens
(
sortish_dl
))
<
sum
(
count_pad_tokens
(
naive_dl
))
assert
len
(
sortish_dl
)
==
len
(
naive_dl
)
def
_get_dataset
(
n_obs
=
1000
,
max_len
=
128
):
if
os
.
getenv
(
"USE_REAL_DATA"
,
False
):
data_dir
=
"examples/seq2seq/wmt_en_ro"
max_tokens
=
max_len
*
2
*
64
if
not
Path
(
data_dir
).
joinpath
(
"train.len"
).
exists
():
save_len_file
(
MARIAN_TINY
,
data_dir
)
else
:
data_dir
=
"examples/seq2seq/test_data/wmt_en_ro"
max_tokens
=
max_len
*
4
save_len_file
(
MARIAN_TINY
,
data_dir
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
MARIAN_TINY
)
ds
=
Seq2SeqDataset
(
tokenizer
,
data_dir
=
data_dir
,
type_path
=
"train"
,
max_source_length
=
max_len
,
max_target_length
=
max_len
,
n_obs
=
n_obs
,
)
return
ds
,
max_tokens
,
tokenizer
def
test_distributed_sortish_sampler_splits_indices_between_procs
():
ds
,
max_tokens
,
tokenizer
=
_get_dataset
()
ids1
=
set
(
DistributedSortishSampler
(
ds
,
256
,
num_replicas
=
2
,
rank
=
0
,
add_extra_examples
=
False
))
ids2
=
set
(
DistributedSortishSampler
(
ds
,
256
,
num_replicas
=
2
,
rank
=
1
,
add_extra_examples
=
False
))
assert
ids1
.
intersection
(
ids2
)
==
set
()
@
pytest
.
mark
.
parametrize
(
"tok_name"
,
[
MBART_TINY
,
MARIAN_TINY
,
T5_TINY
,
BART_TINY
,
PEGASUS_XSUM
,
],
)
def
test_dataset_kwargs
(
tok_name
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
tok_name
)
if
tok_name
==
MBART_TINY
:
train_dataset
=
Seq2SeqDataset
(
train_dataset
=
Seq2SeqDataset
(
tokenizer
,
tokenizer
,
data_dir
=
make_test_data
_dir
()
,
data_dir
=
tmp
_dir
,
type_path
=
"train"
,
type_path
=
"train"
,
max_source_length
=
4
,
max_source_length
=
max_src_len
,
max_target_length
=
8
,
max_target_length
=
max_tgt_len
,
# ignored
src_lang
=
"EN"
,
src_lang
=
src_lang
,
tgt_lang
=
"FR"
,
tgt_lang
=
tgt_lang
,
)
)
kwargs
=
train_dataset
.
dataset_kwargs
dataloader
=
DataLoader
(
train_dataset
,
batch_size
=
2
,
collate_fn
=
train_dataset
.
collate_fn
)
assert
"src_lang"
in
kwargs
and
"tgt_lang"
in
kwargs
for
batch
in
dataloader
:
else
:
assert
isinstance
(
batch
,
dict
)
train_dataset
=
Seq2SeqDataset
(
assert
batch
[
"attention_mask"
].
shape
==
batch
[
"input_ids"
].
shape
tokenizer
,
data_dir
=
make_test_data_dir
(),
type_path
=
"train"
,
max_source_length
=
4
,
max_target_length
=
8
# show that articles were trimmed.
assert
batch
[
"input_ids"
].
shape
[
1
]
==
max_src_len
# show that targets are the same len
assert
batch
[
"labels"
].
shape
[
1
]
==
max_tgt_len
if
tok_name
!=
MBART_TINY
:
continue
# check language codes in correct place
batch
[
"decoder_input_ids"
]
=
shift_tokens_right
(
batch
[
"labels"
],
tokenizer
.
pad_token_id
)
assert
batch
[
"decoder_input_ids"
][
0
,
0
].
item
()
==
tokenizer
.
lang_code_to_id
[
tgt_lang
]
assert
batch
[
"decoder_input_ids"
][
0
,
-
1
].
item
()
==
tokenizer
.
eos_token_id
assert
batch
[
"input_ids"
][
0
,
-
2
].
item
()
==
tokenizer
.
eos_token_id
assert
batch
[
"input_ids"
][
0
,
-
1
].
item
()
==
tokenizer
.
lang_code_to_id
[
src_lang
]
break
# No need to test every batch
@
parameterized
.
expand
([
BART_TINY
,
BERT_BASE_CASED
])
def
test_legacy_dataset_truncation
(
self
,
tok
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
tok
)
tmp_dir
=
make_test_data_dir
(
tmp_dir
=
self
.
get_auto_remove_tmp_dir
())
max_len_source
=
max
(
len
(
tokenizer
.
encode
(
a
))
for
a
in
ARTICLES
)
max_len_target
=
max
(
len
(
tokenizer
.
encode
(
a
))
for
a
in
SUMMARIES
)
trunc_target
=
4
train_dataset
=
LegacySeq2SeqDataset
(
tokenizer
,
data_dir
=
tmp_dir
,
type_path
=
"train"
,
max_source_length
=
20
,
max_target_length
=
trunc_target
,
)
)
kwargs
=
train_dataset
.
dataset_kwargs
dataloader
=
DataLoader
(
train_dataset
,
batch_size
=
2
,
collate_fn
=
train_dataset
.
collate_fn
)
assert
"add_prefix_space"
not
in
kwargs
if
tok_name
!=
BART_TINY
else
"add_prefix_space"
in
kwargs
for
batch
in
dataloader
:
assert
len
(
kwargs
)
==
1
if
tok_name
==
BART_TINY
else
len
(
kwargs
)
==
0
assert
batch
[
"attention_mask"
].
shape
==
batch
[
"input_ids"
].
shape
# show that articles were trimmed.
assert
batch
[
"input_ids"
].
shape
[
1
]
==
max_len_source
assert
20
>=
batch
[
"input_ids"
].
shape
[
1
]
# trimmed significantly
# show that targets were truncated
assert
batch
[
"labels"
].
shape
[
1
]
==
trunc_target
# Truncated
assert
max_len_target
>
trunc_target
# Truncated
break
# No need to test every batch
def
test_pack_dataset
(
self
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"facebook/mbart-large-cc25"
)
tmp_dir
=
Path
(
make_test_data_dir
(
tmp_dir
=
self
.
get_auto_remove_tmp_dir
()))
orig_examples
=
tmp_dir
.
joinpath
(
"train.source"
).
open
().
readlines
()
save_dir
=
Path
(
make_test_data_dir
(
tmp_dir
=
self
.
get_auto_remove_tmp_dir
()))
pack_data_dir
(
tokenizer
,
tmp_dir
,
128
,
save_dir
)
orig_paths
=
{
x
.
name
for
x
in
tmp_dir
.
iterdir
()}
new_paths
=
{
x
.
name
for
x
in
save_dir
.
iterdir
()}
packed_examples
=
save_dir
.
joinpath
(
"train.source"
).
open
().
readlines
()
# orig: [' Sam ate lunch today.\n', 'Sams lunch ingredients.']
# desired_packed: [' Sam ate lunch today.\n Sams lunch ingredients.']
assert
len
(
packed_examples
)
<
len
(
orig_examples
)
assert
len
(
packed_examples
)
==
1
assert
len
(
packed_examples
[
0
])
==
sum
(
len
(
x
)
for
x
in
orig_examples
)
assert
orig_paths
==
new_paths
@
pytest
.
mark
.
skipif
(
not
FAIRSEQ_AVAILABLE
,
reason
=
"This test requires fairseq"
)
def
test_dynamic_batch_size
(
self
):
if
not
FAIRSEQ_AVAILABLE
:
return
ds
,
max_tokens
,
tokenizer
=
self
.
_get_dataset
(
max_len
=
64
)
required_batch_size_multiple
=
64
batch_sampler
=
ds
.
make_dynamic_sampler
(
max_tokens
,
required_batch_size_multiple
=
required_batch_size_multiple
)
batch_sizes
=
[
len
(
x
)
for
x
in
batch_sampler
]
assert
len
(
set
(
batch_sizes
))
>
1
# it's not dynamic batch size if every batch is the same length
assert
sum
(
batch_sizes
)
==
len
(
ds
)
# no dropped or added examples
data_loader
=
DataLoader
(
ds
,
batch_sampler
=
batch_sampler
,
collate_fn
=
ds
.
collate_fn
,
num_workers
=
2
)
failures
=
[]
num_src_per_batch
=
[]
for
batch
in
data_loader
:
src_shape
=
batch
[
"input_ids"
].
shape
bs
=
src_shape
[
0
]
assert
bs
%
required_batch_size_multiple
==
0
or
bs
<
required_batch_size_multiple
num_src_tokens
=
np
.
product
(
batch
[
"input_ids"
].
shape
)
num_src_per_batch
.
append
(
num_src_tokens
)
if
num_src_tokens
>
(
max_tokens
*
1.1
):
failures
.
append
(
num_src_tokens
)
assert
num_src_per_batch
[
0
]
==
max
(
num_src_per_batch
)
if
failures
:
raise
AssertionError
(
f
"too many tokens in
{
len
(
failures
)
}
batches"
)
def
test_sortish_sampler_reduces_padding
(
self
):
ds
,
_
,
tokenizer
=
self
.
_get_dataset
(
max_len
=
512
)
bs
=
2
sortish_sampler
=
ds
.
make_sortish_sampler
(
bs
,
shuffle
=
False
)
naive_dl
=
DataLoader
(
ds
,
batch_size
=
bs
,
collate_fn
=
ds
.
collate_fn
,
num_workers
=
2
)
sortish_dl
=
DataLoader
(
ds
,
batch_size
=
bs
,
collate_fn
=
ds
.
collate_fn
,
num_workers
=
2
,
sampler
=
sortish_sampler
)
pad
=
tokenizer
.
pad_token_id
def
count_pad_tokens
(
data_loader
,
k
=
"input_ids"
):
return
[
batch
[
k
].
eq
(
pad
).
sum
().
item
()
for
batch
in
data_loader
]
assert
sum
(
count_pad_tokens
(
sortish_dl
,
k
=
"labels"
))
<
sum
(
count_pad_tokens
(
naive_dl
,
k
=
"labels"
))
assert
sum
(
count_pad_tokens
(
sortish_dl
))
<
sum
(
count_pad_tokens
(
naive_dl
))
assert
len
(
sortish_dl
)
==
len
(
naive_dl
)
def
_get_dataset
(
self
,
n_obs
=
1000
,
max_len
=
128
):
if
os
.
getenv
(
"USE_REAL_DATA"
,
False
):
data_dir
=
"examples/seq2seq/wmt_en_ro"
max_tokens
=
max_len
*
2
*
64
if
not
Path
(
data_dir
).
joinpath
(
"train.len"
).
exists
():
save_len_file
(
MARIAN_TINY
,
data_dir
)
else
:
data_dir
=
"examples/seq2seq/test_data/wmt_en_ro"
max_tokens
=
max_len
*
4
save_len_file
(
MARIAN_TINY
,
data_dir
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
MARIAN_TINY
)
ds
=
Seq2SeqDataset
(
tokenizer
,
data_dir
=
data_dir
,
type_path
=
"train"
,
max_source_length
=
max_len
,
max_target_length
=
max_len
,
n_obs
=
n_obs
,
)
return
ds
,
max_tokens
,
tokenizer
def
test_distributed_sortish_sampler_splits_indices_between_procs
(
self
):
ds
,
max_tokens
,
tokenizer
=
self
.
_get_dataset
()
ids1
=
set
(
DistributedSortishSampler
(
ds
,
256
,
num_replicas
=
2
,
rank
=
0
,
add_extra_examples
=
False
))
ids2
=
set
(
DistributedSortishSampler
(
ds
,
256
,
num_replicas
=
2
,
rank
=
1
,
add_extra_examples
=
False
))
assert
ids1
.
intersection
(
ids2
)
==
set
()
@parameterized.expand(
    [MBART_TINY, MARIAN_TINY, T5_TINY, BART_TINY, PEGASUS_XSUM],
)
def test_dataset_kwargs(self, tok_name):
    """Check which entries ``Seq2SeqDataset.dataset_kwargs`` holds per tokenizer.

    For ``MBART_TINY`` the dataset is created with ``src_lang``/``tgt_lang``
    and both keys must be present in ``dataset_kwargs``. For every other
    tokenizer no language arguments are passed; ``dataset_kwargs`` must then
    contain exactly ``add_prefix_space`` for ``BART_TINY`` and be empty for
    the rest.
    """
    tokenizer = AutoTokenizer.from_pretrained(tok_name)
    is_mbart = tok_name == MBART_TINY

    # Language codes are supplied only in the mBART case.
    lang_kwargs = {"src_lang": "EN", "tgt_lang": "FR"} if is_mbart else {}
    train_dataset = Seq2SeqDataset(
        tokenizer,
        data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
        type_path="train",
        max_source_length=4,
        max_target_length=8,
        **lang_kwargs,
    )
    kwargs = train_dataset.dataset_kwargs

    if is_mbart:
        assert "src_lang" in kwargs and "tgt_lang" in kwargs
        return

    if tok_name == BART_TINY:
        assert "add_prefix_space" in kwargs
        assert len(kwargs) == 1
    else:
        assert "add_prefix_space" not in kwargs
        assert len(kwargs) == 0
examples/seq2seq/test_finetune_trainer.py
View file @
9f7b2b24
import
os
import
os
import
sys
import
sys
import
tempfile
from
unittest.mock
import
patch
from
unittest.mock
import
patch
from
transformers.testing_utils
import
slow
from
transformers.testing_utils
import
TestCasePlus
,
slow
from
transformers.trainer_callback
import
TrainerState
from
transformers.trainer_callback
import
TrainerState
from
transformers.trainer_utils
import
set_seed
from
transformers.trainer_utils
import
set_seed
...
@@ -15,72 +14,71 @@ set_seed(42)
...
@@ -15,72 +14,71 @@ set_seed(42)
MARIAN_MODEL
=
"sshleifer/student_marian_en_ro_6_1"
MARIAN_MODEL
=
"sshleifer/student_marian_en_ro_6_1"
def
test_finetune_trainer
():
class
TestFinetuneTrainer
(
TestCasePlus
):
output_dir
=
run_trainer
(
1
,
"12"
,
MBART_TINY
,
1
)
def
test_finetune_trainer
(
self
):
logs
=
TrainerState
.
load_from_json
(
os
.
path
.
join
(
output_dir
,
"trainer_state.json"
)).
log_history
output_dir
=
self
.
run_trainer
(
1
,
"12"
,
MBART_TINY
,
1
)
eval_metrics
=
[
log
for
log
in
logs
if
"eval_loss"
in
log
.
keys
()]
logs
=
TrainerState
.
load_from_json
(
os
.
path
.
join
(
output_dir
,
"trainer_state.json"
)).
log_history
first_step_stats
=
eval_metrics
[
0
]
eval_metrics
=
[
log
for
log
in
logs
if
"eval_loss"
in
log
.
keys
()]
assert
"eval_bleu"
in
first_step_stats
first_step_stats
=
eval_metrics
[
0
]
assert
"eval_bleu"
in
first_step_stats
@
slow
def
test_finetune_trainer_slow
(
self
):
# There is a missing call to __init__process_group somewhere
output_dir
=
self
.
run_trainer
(
eval_steps
=
2
,
max_len
=
"128"
,
model_name
=
MARIAN_MODEL
,
num_train_epochs
=
3
)
@
slow
# Check metrics
def
test_finetune_trainer_slow
():
logs
=
TrainerState
.
load_from_json
(
os
.
path
.
join
(
output_dir
,
"trainer_state.json"
)).
log_history
# There is a missing call to __init__process_group somewhere
eval_metrics
=
[
log
for
log
in
logs
if
"eval_loss"
in
log
.
keys
()]
output_dir
=
run_trainer
(
eval_steps
=
2
,
max_len
=
"128"
,
model_name
=
MARIAN_MODEL
,
num_train_epochs
=
3
)
first_step_stats
=
eval_metrics
[
0
]
last_step_stats
=
eval_metrics
[
-
1
]
# Check metrics
assert
first_step_stats
[
"eval_bleu"
]
<
last_step_stats
[
"eval_bleu"
]
# model learned nothing
logs
=
TrainerState
.
load_from_json
(
os
.
path
.
join
(
output_dir
,
"trainer_state.json"
)).
log_history
assert
isinstance
(
last_step_stats
[
"eval_bleu"
],
float
)
eval_metrics
=
[
log
for
log
in
logs
if
"eval_loss"
in
log
.
keys
()]
first_step_stats
=
eval_metrics
[
0
]
last_step_stats
=
eval_metrics
[
-
1
]
assert
first_step_stats
[
"eval_bleu"
]
<
last_step_stats
[
"eval_bleu"
]
# model learned nothing
# test if do_predict saves generations and metrics
assert
isinstance
(
last_step_stats
[
"eval_bleu"
],
float
)
contents
=
os
.
listdir
(
output_dir
)
contents
=
{
os
.
path
.
basename
(
p
)
for
p
in
contents
}
assert
"test_generations.txt"
in
contents
assert
"test_results.json"
in
contents
# test if do_predict saves generations and metrics
def
run_trainer
(
self
,
eval_steps
:
int
,
max_len
:
str
,
model_name
:
str
,
num_train_epochs
:
int
):
contents
=
os
.
listdir
(
output_dir
)
data_dir
=
"examples/seq2seq/test_data/wmt_en_ro"
contents
=
{
os
.
path
.
basename
(
p
)
for
p
in
contents
}
output_dir
=
self
.
get_auto_remove_tmp_dir
()
assert
"test_generations.txt"
in
contents
argv
=
f
"""
assert
"test_results.json"
in
contents
--model_name_or_path
{
model_name
}
--data_dir
{
data_dir
}
--output_dir
{
output_dir
}
--overwrite_output_dir
--n_train 8
--n_val 8
--max_source_length
{
max_len
}
--max_target_length
{
max_len
}
--val_max_target_length
{
max_len
}
--do_train
--do_eval
--do_predict
--num_train_epochs
{
str
(
num_train_epochs
)
}
--per_device_train_batch_size 4
--per_device_eval_batch_size 4
--learning_rate 3e-4
--warmup_steps 8
--evaluate_during_training
--predict_with_generate
--logging_steps 0
--save_steps
{
str
(
eval_steps
)
}
--eval_steps
{
str
(
eval_steps
)
}
--sortish_sampler
--label_smoothing 0.1
--adafactor
--task translation
--tgt_lang ro_RO
--src_lang en_XX
"""
.
split
()
# --eval_beams 2
testargs
=
[
"finetune_trainer.py"
]
+
argv
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
main
()
def
run_trainer
(
eval_steps
:
int
,
max_len
:
str
,
model_name
:
str
,
num_train_epochs
:
int
):
return
output_dir
data_dir
=
"examples/seq2seq/test_data/wmt_en_ro"
output_dir
=
tempfile
.
mkdtemp
(
prefix
=
"test_output"
)
argv
=
f
"""
--model_name_or_path
{
model_name
}
--data_dir
{
data_dir
}
--output_dir
{
output_dir
}
--overwrite_output_dir
--n_train 8
--n_val 8
--max_source_length
{
max_len
}
--max_target_length
{
max_len
}
--val_max_target_length
{
max_len
}
--do_train
--do_eval
--do_predict
--num_train_epochs
{
str
(
num_train_epochs
)
}
--per_device_train_batch_size 4
--per_device_eval_batch_size 4
--learning_rate 3e-4
--warmup_steps 8
--evaluate_during_training
--predict_with_generate
--logging_steps 0
--save_steps
{
str
(
eval_steps
)
}
--eval_steps
{
str
(
eval_steps
)
}
--sortish_sampler
--label_smoothing 0.1
--adafactor
--task translation
--tgt_lang ro_RO
--src_lang en_XX
"""
.
split
()
# --eval_beams 2
testargs
=
[
"finetune_trainer.py"
]
+
argv
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
main
()
return
output_dir
examples/seq2seq/test_seq2seq_examples.py
View file @
9f7b2b24
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment