chenpangpang / transformers · Commits

Commit 9f7b2b24 (Unverified)
[s2s testing] turn all to unittests, use auto-delete temp dirs (#7859)
Authored Oct 17, 2020 by Stas Bekman; committed by GitHub on Oct 17, 2020
parent dc552b9b
parent
dc552b9b
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
674 additions
and
694 deletions
+674
-694
examples/seq2seq/test_bash_script.py
examples/seq2seq/test_bash_script.py
+162
-162
examples/seq2seq/test_data/wmt_en_ro/train.len
examples/seq2seq/test_data/wmt_en_ro/train.len
+0
-0
examples/seq2seq/test_data/wmt_en_ro/val.len
examples/seq2seq/test_data/wmt_en_ro/val.len
+0
-0
examples/seq2seq/test_datasets.py
examples/seq2seq/test_datasets.py
+191
-195
examples/seq2seq/test_finetune_trainer.py
examples/seq2seq/test_finetune_trainer.py
+62
-64
examples/seq2seq/test_seq2seq_examples.py
examples/seq2seq/test_seq2seq_examples.py
+259
-273
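The change is mechanical across all three test modules: module-level pytest functions become methods on a TestCasePlus subclass, and every tempfile.mkdtemp(...) scratch directory becomes self.get_auto_remove_tmp_dir(), which the base class deletes when the test finishes. A minimal sketch of the resulting pattern, assuming only the no-argument get_auto_remove_tmp_dir() call that the diffs below use:

    # Minimal sketch of the pattern this commit adopts; assumes only that
    # TestCasePlus provides get_auto_remove_tmp_dir(), as the diffs below show.
    import os
    import unittest

    from transformers.testing_utils import TestCasePlus


    class TestTmpDirPattern(TestCasePlus):
        def test_writes_to_auto_removed_dir(self):
            # The returned directory exists for the duration of the test and is
            # cleaned up automatically afterwards, so nothing leaks into /tmp.
            tmp_dir = self.get_auto_remove_tmp_dir()
            with open(os.path.join(tmp_dir, "out.txt"), "w") as f:
                f.write("scratch output")
            assert os.path.exists(os.path.join(tmp_dir, "out.txt"))


    if __name__ == "__main__":
        unittest.main()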
examples/seq2seq/test_bash_script.py
@@ -3,7 +3,6 @@
 import argparse
 import os
 import sys
-import tempfile
 from pathlib import Path
 from unittest.mock import patch
@@ -16,7 +15,7 @@ from distillation import BartSummarizationDistiller, distill_main
 from finetune import SummarizationModule, main
 from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
 from transformers import BartForConditionalGeneration, MarianMTModel
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from utils import load_json
@@ -24,163 +23,164 @@ MODEL_NAME = MBART_TINY
 MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
 
-@slow
-@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
-def test_model_download():
-    """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
-    BartForConditionalGeneration.from_pretrained(MODEL_NAME)
-    MarianMTModel.from_pretrained(MARIAN_MODEL)
-
-
-@timeout_decorator.timeout(120)
-@slow
-@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
-def test_train_mbart_cc25_enro_script():
-    data_dir = "examples/seq2seq/test_data/wmt_en_ro"
-    env_vars_to_replace = {
-        "--fp16_opt_level=O1": "",
-        "$MAX_LEN": 128,
-        "$BS": 4,
-        "$GAS": 1,
-        "$ENRO_DIR": data_dir,
-        "facebook/mbart-large-cc25": MODEL_NAME,  # Download is 120MB in previous test.
-        "val_check_interval=0.25": "val_check_interval=1.0",
-    }
-    # Clean up bash script
-    bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
-    bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
-    for k, v in env_vars_to_replace.items():
-        bash_script = bash_script.replace(k, str(v))
-    output_dir = tempfile.mkdtemp(prefix="output_mbart")
-    bash_script = bash_script.replace("--fp16 ", "")
-    testargs = (
-        ["finetune.py"]
-        + bash_script.split()
-        + [
-            f"--output_dir={output_dir}",
-            "--gpus=1",
-            "--learning_rate=3e-1",
-            "--warmup_steps=0",
-            "--val_check_interval=1.0",
-            "--tokenizer_name=facebook/mbart-large-en-ro",
-        ]
-    )
-    with patch.object(sys, "argv", testargs):
-        parser = argparse.ArgumentParser()
-        parser = pl.Trainer.add_argparse_args(parser)
-        parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
-        args = parser.parse_args()
-        args.do_predict = False
-        # assert args.gpus == gpus THIS BREAKS for multigpu
-        model = main(args)
-
-    # Check metrics
-    metrics = load_json(model.metrics_save_path)
-    first_step_stats = metrics["val"][0]
-    last_step_stats = metrics["val"][-1]
-    assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1  # +1 accounts for val_sanity_check
-    assert last_step_stats["val_avg_gen_time"] >= 0.01
-    assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]  # model learned nothing
-    assert 1.0 >= last_step_stats["val_avg_gen_time"]  # model hanging on generate. Maybe bad config was saved.
-    assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
-
-    # check lightning ckpt can be loaded and has a reasonable statedict
-    contents = os.listdir(output_dir)
-    ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
-    full_path = os.path.join(args.output_dir, ckpt_path)
-    ckpt = torch.load(full_path, map_location="cpu")
-    expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
-    assert expected_key in ckpt["state_dict"]
-    assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
-
-    # TODO: turn on args.do_predict when PL bug fixed.
-    if args.do_predict:
-        contents = {os.path.basename(p) for p in contents}
-        assert "test_generations.txt" in contents
-        assert "test_results.txt" in contents
-        # assert len(metrics["val"]) == desired_n_evals
-        assert len(metrics["test"]) == 1
-
-
-@timeout_decorator.timeout(600)
-@slow
-@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
-def test_opus_mt_distill_script():
-    data_dir = "examples/seq2seq/test_data/wmt_en_ro"
-    env_vars_to_replace = {
-        "--fp16_opt_level=O1": "",
-        "$MAX_LEN": 128,
-        "$BS": 16,
-        "$GAS": 1,
-        "$ENRO_DIR": data_dir,
-        "$m": "sshleifer/student_marian_en_ro_6_1",
-        "val_check_interval=0.25": "val_check_interval=1.0",
-    }
-    # Clean up bash script
-    bash_script = (
-        Path("examples/seq2seq/distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
-    )
-    bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
-    bash_script = bash_script.replace("--fp16 ", " ")
-    for k, v in env_vars_to_replace.items():
-        bash_script = bash_script.replace(k, str(v))
-    output_dir = tempfile.mkdtemp(prefix="marian_output")
-    bash_script = bash_script.replace("--fp16", "")
-    epochs = 6
-    testargs = (
-        ["distillation.py"]
-        + bash_script.split()
-        + [
-            f"--output_dir={output_dir}",
-            "--gpus=1",
-            "--learning_rate=1e-3",
-            f"--num_train_epochs={epochs}",
-            "--warmup_steps=10",
-            "--val_check_interval=1.0",
-        ]
-    )
-    with patch.object(sys, "argv", testargs):
-        parser = argparse.ArgumentParser()
-        parser = pl.Trainer.add_argparse_args(parser)
-        parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
-        args = parser.parse_args()
-        args.do_predict = False
-        # assert args.gpus == gpus THIS BREAKS for multigpu
-        model = distill_main(args)
-
-    # Check metrics
-    metrics = load_json(model.metrics_save_path)
-    first_step_stats = metrics["val"][0]
-    last_step_stats = metrics["val"][-1]
-    assert len(metrics["val"]) >= (args.max_epochs / args.val_check_interval)  # +1 accounts for val_sanity_check
-    assert last_step_stats["val_avg_gen_time"] >= 0.01
-    assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]  # model learned nothing
-    assert 1.0 >= last_step_stats["val_avg_gen_time"]  # model hanging on generate. Maybe bad config was saved.
-    assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
-
-    # check lightning ckpt can be loaded and has a reasonable statedict
-    contents = os.listdir(output_dir)
-    ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
-    full_path = os.path.join(args.output_dir, ckpt_path)
-    ckpt = torch.load(full_path, map_location="cpu")
-    expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
-    assert expected_key in ckpt["state_dict"]
-    assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
-
-    # TODO: turn on args.do_predict when PL bug fixed.
-    if args.do_predict:
-        contents = {os.path.basename(p) for p in contents}
-        assert "test_generations.txt" in contents
-        assert "test_results.txt" in contents
-        # assert len(metrics["val"]) == desired_n_evals
-        assert len(metrics["test"]) == 1
+class TestAll(TestCasePlus):
+    @slow
+    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    def test_model_download(self):
+        """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
+        BartForConditionalGeneration.from_pretrained(MODEL_NAME)
+        MarianMTModel.from_pretrained(MARIAN_MODEL)
+
+    @timeout_decorator.timeout(120)
+    @slow
+    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    def test_train_mbart_cc25_enro_script(self):
+        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
+        env_vars_to_replace = {
+            "--fp16_opt_level=O1": "",
+            "$MAX_LEN": 128,
+            "$BS": 4,
+            "$GAS": 1,
+            "$ENRO_DIR": data_dir,
+            "facebook/mbart-large-cc25": MODEL_NAME,  # Download is 120MB in previous test.
+            "val_check_interval=0.25": "val_check_interval=1.0",
+        }
+        # Clean up bash script
+        bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
+        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
+        for k, v in env_vars_to_replace.items():
+            bash_script = bash_script.replace(k, str(v))
+        output_dir = self.get_auto_remove_tmp_dir()
+        bash_script = bash_script.replace("--fp16 ", "")
+        testargs = (
+            ["finetune.py"]
+            + bash_script.split()
+            + [
+                f"--output_dir={output_dir}",
+                "--gpus=1",
+                "--learning_rate=3e-1",
+                "--warmup_steps=0",
+                "--val_check_interval=1.0",
+                "--tokenizer_name=facebook/mbart-large-en-ro",
+            ]
+        )
+        with patch.object(sys, "argv", testargs):
+            parser = argparse.ArgumentParser()
+            parser = pl.Trainer.add_argparse_args(parser)
+            parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
+            args = parser.parse_args()
+            args.do_predict = False
+            # assert args.gpus == gpus THIS BREAKS for multigpu
+            model = main(args)
+
+        # Check metrics
+        metrics = load_json(model.metrics_save_path)
+        first_step_stats = metrics["val"][0]
+        last_step_stats = metrics["val"][-1]
+        assert (
+            len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1
+        )  # +1 accounts for val_sanity_check
+        assert last_step_stats["val_avg_gen_time"] >= 0.01
+        assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]  # model learned nothing
+        assert 1.0 >= last_step_stats["val_avg_gen_time"]  # model hanging on generate. Maybe bad config was saved.
+        assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
+
+        # check lightning ckpt can be loaded and has a reasonable statedict
+        contents = os.listdir(output_dir)
+        ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
+        full_path = os.path.join(args.output_dir, ckpt_path)
+        ckpt = torch.load(full_path, map_location="cpu")
+        expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
+        assert expected_key in ckpt["state_dict"]
+        assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
+
+        # TODO: turn on args.do_predict when PL bug fixed.
+        if args.do_predict:
+            contents = {os.path.basename(p) for p in contents}
+            assert "test_generations.txt" in contents
+            assert "test_results.txt" in contents
+            # assert len(metrics["val"]) == desired_n_evals
+            assert len(metrics["test"]) == 1
+
+    @timeout_decorator.timeout(600)
+    @slow
+    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    def test_opus_mt_distill_script(self):
+        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
+        env_vars_to_replace = {
+            "--fp16_opt_level=O1": "",
+            "$MAX_LEN": 128,
+            "$BS": 16,
+            "$GAS": 1,
+            "$ENRO_DIR": data_dir,
+            "$m": "sshleifer/student_marian_en_ro_6_1",
+            "val_check_interval=0.25": "val_check_interval=1.0",
+        }
+        # Clean up bash script
+        bash_script = (
+            Path("examples/seq2seq/distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
+        )
+        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
+        bash_script = bash_script.replace("--fp16 ", " ")
+        for k, v in env_vars_to_replace.items():
+            bash_script = bash_script.replace(k, str(v))
+        output_dir = self.get_auto_remove_tmp_dir()
+        bash_script = bash_script.replace("--fp16", "")
+        epochs = 6
+        testargs = (
+            ["distillation.py"]
+            + bash_script.split()
+            + [
+                f"--output_dir={output_dir}",
+                "--gpus=1",
+                "--learning_rate=1e-3",
+                f"--num_train_epochs={epochs}",
+                "--warmup_steps=10",
+                "--val_check_interval=1.0",
+            ]
+        )
+        with patch.object(sys, "argv", testargs):
+            parser = argparse.ArgumentParser()
+            parser = pl.Trainer.add_argparse_args(parser)
+            parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
+            args = parser.parse_args()
+            args.do_predict = False
+            # assert args.gpus == gpus THIS BREAKS for multigpu
+            model = distill_main(args)
+
+        # Check metrics
+        metrics = load_json(model.metrics_save_path)
+        first_step_stats = metrics["val"][0]
+        last_step_stats = metrics["val"][-1]
+        assert len(metrics["val"]) >= (args.max_epochs / args.val_check_interval)  # +1 accounts for val_sanity_check
+        assert last_step_stats["val_avg_gen_time"] >= 0.01
+        assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]  # model learned nothing
+        assert 1.0 >= last_step_stats["val_avg_gen_time"]  # model hanging on generate. Maybe bad config was saved.
+        assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
+
+        # check lightning ckpt can be loaded and has a reasonable statedict
+        contents = os.listdir(output_dir)
+        ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
+        full_path = os.path.join(args.output_dir, ckpt_path)
+        ckpt = torch.load(full_path, map_location="cpu")
+        expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
+        assert expected_key in ckpt["state_dict"]
+        assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
+
+        # TODO: turn on args.do_predict when PL bug fixed.
+        if args.do_predict:
+            contents = {os.path.basename(p) for p in contents}
+            assert "test_generations.txt" in contents
+            assert "test_results.txt" in contents
+            # assert len(metrics["val"]) == desired_n_evals
+            assert len(metrics["test"]) == 1
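Both script tests above share one munging technique: take everything in the shipped bash launcher after the entry-point name, strip line continuations and the trailing "$@", substitute the $VAR placeholders, and split the result into an argv list. A reduced, self-contained sketch of that step (the script text here is invented for illustration):

    # Reduced sketch of turning a bash launcher into a Python argv list,
    # as the tests above do; the script text is invented for illustration.
    script_text = """python finetune.py \\
        --data_dir $ENRO_DIR \\
        --train_batch_size=$BS \\
        "$@"
    """

    env_vars_to_replace = {"$ENRO_DIR": "test_data/wmt_en_ro", "$BS": 4}

    # Keep only what follows the entry-point name, then flatten continuations.
    args = script_text.split("finetune.py")[1].strip()
    args = args.replace("\\\n", "").strip().replace('"$@"', "")
    for k, v in env_vars_to_replace.items():
        args = args.replace(k, str(v))

    argv = ["finetune.py"] + args.split()
    print(argv)  # ['finetune.py', '--data_dir', 'test_data/wmt_en_ro', '--train_batch_size=4']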
examples/seq2seq/test_data/wmt_en_ro/train.len
No preview for this file type.
examples/seq2seq/test_data/wmt_en_ro/val.len
No preview for this file type.
examples/seq2seq/test_datasets.py
 import os
-import tempfile
 from pathlib import Path
 
 import numpy as np
@@ -7,11 +6,12 @@ import pytest
 from torch.utils.data import DataLoader
 
 from pack_dataset import pack_data_dir
+from parameterized import parameterized
 from save_len_file import save_len_file
 from test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
 from transformers import AutoTokenizer
 from transformers.modeling_bart import shift_tokens_right
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset
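One consequence of moving into a unittest.TestCase subclass: pytest's @pytest.mark.parametrize does not work on TestCase methods (pytest documents this limitation), which is why the diff swaps it for parameterized.expand from the parameterized package. A small self-contained sketch (the test and values are illustrative, not from this commit):

    # Sketch: parametrizing inside a unittest.TestCase subclass.
    # @pytest.mark.parametrize cannot be used on TestCase methods, so the
    # diff switches to parameterized.expand, which generates one test
    # method per argument at class-definition time.
    import unittest

    from parameterized import parameterized


    class TestUpper(unittest.TestCase):
        @parameterized.expand(["bart", "marian", "t5"])  # illustrative values
        def test_upper(self, name):
            # One generated test per name, e.g. test_upper_0_bart.
            self.assertTrue(name.upper().isupper())


    if __name__ == "__main__":
        unittest.main()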
@@ -19,202 +19,198 @@ BERT_BASE_CASED = "bert-base-cased"
 PEGASUS_XSUM = "google/pegasus-xsum"
 
-@slow
-@pytest.mark.parametrize(
-    "tok_name",
-    [
-        MBART_TINY,
-        MARIAN_TINY,
-        T5_TINY,
-        BART_TINY,
-        PEGASUS_XSUM,
-    ],
-)
-def test_seq2seq_dataset_truncation(tok_name):
-    tokenizer = AutoTokenizer.from_pretrained(tok_name)
-    tmp_dir = make_test_data_dir()
-    max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
-    max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
-    max_src_len = 4
-    max_tgt_len = 8
-    assert max_len_target > max_src_len  # Will be truncated
-    assert max_len_source > max_src_len  # Will be truncated
-    src_lang, tgt_lang = "ro_RO", "de_DE"  # ignored for all but mbart, but never causes error.
-    train_dataset = Seq2SeqDataset(
-        tokenizer,
-        data_dir=tmp_dir,
-        type_path="train",
-        max_source_length=max_src_len,
-        max_target_length=max_tgt_len,  # ignored
-        src_lang=src_lang,
-        tgt_lang=tgt_lang,
-    )
-    dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
-    for batch in dataloader:
-        assert isinstance(batch, dict)
-        assert batch["attention_mask"].shape == batch["input_ids"].shape
-        # show that articles were trimmed.
-        assert batch["input_ids"].shape[1] == max_src_len
-        # show that targets are the same len
-        assert batch["labels"].shape[1] == max_tgt_len
-        if tok_name != MBART_TINY:
-            continue
-        # check language codes in correct place
-        batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id)
-        assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
-        assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
-        assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
-        assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
-        break  # No need to test every batch
-
-
-@pytest.mark.parametrize("tok", [BART_TINY, BERT_BASE_CASED])
-def test_legacy_dataset_truncation(tok):
-    tokenizer = AutoTokenizer.from_pretrained(tok)
-    tmp_dir = make_test_data_dir()
-    max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
-    max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
-    trunc_target = 4
-    train_dataset = LegacySeq2SeqDataset(
-        tokenizer,
-        data_dir=tmp_dir,
-        type_path="train",
-        max_source_length=20,
-        max_target_length=trunc_target,
-    )
-    dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
-    for batch in dataloader:
-        assert batch["attention_mask"].shape == batch["input_ids"].shape
-        # show that articles were trimmed.
-        assert batch["input_ids"].shape[1] == max_len_source
-        assert 20 >= batch["input_ids"].shape[1]  # trimmed significantly
-        # show that targets were truncated
-        assert batch["labels"].shape[1] == trunc_target  # Truncated
-        assert max_len_target > trunc_target  # Truncated
-        break  # No need to test every batch
-
-
-def test_pack_dataset():
-    tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
-
-    tmp_dir = Path(make_test_data_dir())
-    orig_examples = tmp_dir.joinpath("train.source").open().readlines()
-    save_dir = Path(tempfile.mkdtemp(prefix="packed_"))
-    pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
-    orig_paths = {x.name for x in tmp_dir.iterdir()}
-    new_paths = {x.name for x in save_dir.iterdir()}
-    packed_examples = save_dir.joinpath("train.source").open().readlines()
-    # orig: [' Sam ate lunch today.\n', 'Sams lunch ingredients.']
-    # desired_packed: [' Sam ate lunch today.\n Sams lunch ingredients.']
-    assert len(packed_examples) < len(orig_examples)
-    assert len(packed_examples) == 1
-    assert len(packed_examples[0]) == sum(len(x) for x in orig_examples)
-    assert orig_paths == new_paths
-
-
-@pytest.mark.skipif(not FAIRSEQ_AVAILABLE, reason="This test requires fairseq")
-def test_dynamic_batch_size():
-    if not FAIRSEQ_AVAILABLE:
-        return
-    ds, max_tokens, tokenizer = _get_dataset(max_len=64)
-    required_batch_size_multiple = 64
-    batch_sampler = ds.make_dynamic_sampler(max_tokens, required_batch_size_multiple=required_batch_size_multiple)
-    batch_sizes = [len(x) for x in batch_sampler]
-    assert len(set(batch_sizes)) > 1  # it's not dynamic batch size if every batch is the same length
-    assert sum(batch_sizes) == len(ds)  # no dropped or added examples
-    data_loader = DataLoader(ds, batch_sampler=batch_sampler, collate_fn=ds.collate_fn, num_workers=2)
-    failures = []
-    num_src_per_batch = []
-    for batch in data_loader:
-        src_shape = batch["input_ids"].shape
-        bs = src_shape[0]
-        assert bs % required_batch_size_multiple == 0 or bs < required_batch_size_multiple
-        num_src_tokens = np.product(batch["input_ids"].shape)
-        num_src_per_batch.append(num_src_tokens)
-        if num_src_tokens > (max_tokens * 1.1):
-            failures.append(num_src_tokens)
-    assert num_src_per_batch[0] == max(num_src_per_batch)
-    if failures:
-        raise AssertionError(f"too many tokens in {len(failures)} batches")
-
-
-def test_sortish_sampler_reduces_padding():
-    ds, _, tokenizer = _get_dataset(max_len=512)
-    bs = 2
-    sortish_sampler = ds.make_sortish_sampler(bs, shuffle=False)
-
-    naive_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2)
-    sortish_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2, sampler=sortish_sampler)
-
-    pad = tokenizer.pad_token_id
-
-    def count_pad_tokens(data_loader, k="input_ids"):
-        return [batch[k].eq(pad).sum().item() for batch in data_loader]
-
-    assert sum(count_pad_tokens(sortish_dl, k="labels")) < sum(count_pad_tokens(naive_dl, k="labels"))
-    assert sum(count_pad_tokens(sortish_dl)) < sum(count_pad_tokens(naive_dl))
-    assert len(sortish_dl) == len(naive_dl)
-
-
-def _get_dataset(n_obs=1000, max_len=128):
-    if os.getenv("USE_REAL_DATA", False):
-        data_dir = "examples/seq2seq/wmt_en_ro"
-        max_tokens = max_len * 2 * 64
-        if not Path(data_dir).joinpath("train.len").exists():
-            save_len_file(MARIAN_TINY, data_dir)
-    else:
-        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
-        max_tokens = max_len * 4
-        save_len_file(MARIAN_TINY, data_dir)
-
-    tokenizer = AutoTokenizer.from_pretrained(MARIAN_TINY)
-    ds = Seq2SeqDataset(
-        tokenizer,
-        data_dir=data_dir,
-        type_path="train",
-        max_source_length=max_len,
-        max_target_length=max_len,
-        n_obs=n_obs,
-    )
-    return ds, max_tokens, tokenizer
-
-
-def test_distributed_sortish_sampler_splits_indices_between_procs():
-    ds, max_tokens, tokenizer = _get_dataset()
-    ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
-    ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
-    assert ids1.intersection(ids2) == set()
-
-
-@pytest.mark.parametrize(
-    "tok_name",
-    [
-        MBART_TINY,
-        MARIAN_TINY,
-        T5_TINY,
-        BART_TINY,
-        PEGASUS_XSUM,
-    ],
-)
-def test_dataset_kwargs(tok_name):
-    tokenizer = AutoTokenizer.from_pretrained(tok_name)
-    if tok_name == MBART_TINY:
-        train_dataset = Seq2SeqDataset(
-            tokenizer,
-            data_dir=make_test_data_dir(),
-            type_path="train",
-            max_source_length=4,
-            max_target_length=8,
-            src_lang="EN",
-            tgt_lang="FR",
-        )
-        kwargs = train_dataset.dataset_kwargs
-        assert "src_lang" in kwargs and "tgt_lang" in kwargs
-    else:
-        train_dataset = Seq2SeqDataset(
-            tokenizer,
-            data_dir=make_test_data_dir(),
-            type_path="train",
-            max_source_length=4,
-            max_target_length=8,
-        )
-        kwargs = train_dataset.dataset_kwargs
-        assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs
-        assert len(kwargs) == 1 if tok_name == BART_TINY else len(kwargs) == 0
+class TestAll(TestCasePlus):
+    @parameterized.expand(
+        [
+            MBART_TINY,
+            MARIAN_TINY,
+            T5_TINY,
+            BART_TINY,
+            PEGASUS_XSUM,
+        ],
+    )
+    @slow
+    def test_seq2seq_dataset_truncation(self, tok_name):
+        tokenizer = AutoTokenizer.from_pretrained(tok_name)
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
+        max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
+        max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
+        max_src_len = 4
+        max_tgt_len = 8
+        assert max_len_target > max_src_len  # Will be truncated
+        assert max_len_source > max_src_len  # Will be truncated
+        src_lang, tgt_lang = "ro_RO", "de_DE"  # ignored for all but mbart, but never causes error.
+        train_dataset = Seq2SeqDataset(
+            tokenizer,
+            data_dir=tmp_dir,
+            type_path="train",
+            max_source_length=max_src_len,
+            max_target_length=max_tgt_len,  # ignored
+            src_lang=src_lang,
+            tgt_lang=tgt_lang,
+        )
+        dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
+        for batch in dataloader:
+            assert isinstance(batch, dict)
+            assert batch["attention_mask"].shape == batch["input_ids"].shape
+            # show that articles were trimmed.
+            assert batch["input_ids"].shape[1] == max_src_len
+            # show that targets are the same len
+            assert batch["labels"].shape[1] == max_tgt_len
+            if tok_name != MBART_TINY:
+                continue
+            # check language codes in correct place
+            batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id)
+            assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
+            assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
+            assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
+            assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
+            break  # No need to test every batch
+
+    @parameterized.expand([BART_TINY, BERT_BASE_CASED])
+    def test_legacy_dataset_truncation(self, tok):
+        tokenizer = AutoTokenizer.from_pretrained(tok)
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
+        max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
+        max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
+        trunc_target = 4
+        train_dataset = LegacySeq2SeqDataset(
+            tokenizer,
+            data_dir=tmp_dir,
+            type_path="train",
+            max_source_length=20,
+            max_target_length=trunc_target,
+        )
+        dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
+        for batch in dataloader:
+            assert batch["attention_mask"].shape == batch["input_ids"].shape
+            # show that articles were trimmed.
+            assert batch["input_ids"].shape[1] == max_len_source
+            assert 20 >= batch["input_ids"].shape[1]  # trimmed significantly
+            # show that targets were truncated
+            assert batch["labels"].shape[1] == trunc_target  # Truncated
+            assert max_len_target > trunc_target  # Truncated
+            break  # No need to test every batch
+
+    def test_pack_dataset(self):
+        tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
+
+        tmp_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
+        orig_examples = tmp_dir.joinpath("train.source").open().readlines()
+        save_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
+        pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
+        orig_paths = {x.name for x in tmp_dir.iterdir()}
+        new_paths = {x.name for x in save_dir.iterdir()}
+        packed_examples = save_dir.joinpath("train.source").open().readlines()
+        # orig: [' Sam ate lunch today.\n', 'Sams lunch ingredients.']
+        # desired_packed: [' Sam ate lunch today.\n Sams lunch ingredients.']
+        assert len(packed_examples) < len(orig_examples)
+        assert len(packed_examples) == 1
+        assert len(packed_examples[0]) == sum(len(x) for x in orig_examples)
+        assert orig_paths == new_paths
+
+    @pytest.mark.skipif(not FAIRSEQ_AVAILABLE, reason="This test requires fairseq")
+    def test_dynamic_batch_size(self):
+        if not FAIRSEQ_AVAILABLE:
+            return
+        ds, max_tokens, tokenizer = self._get_dataset(max_len=64)
+        required_batch_size_multiple = 64
+        batch_sampler = ds.make_dynamic_sampler(max_tokens, required_batch_size_multiple=required_batch_size_multiple)
+        batch_sizes = [len(x) for x in batch_sampler]
+        assert len(set(batch_sizes)) > 1  # it's not dynamic batch size if every batch is the same length
+        assert sum(batch_sizes) == len(ds)  # no dropped or added examples
+        data_loader = DataLoader(ds, batch_sampler=batch_sampler, collate_fn=ds.collate_fn, num_workers=2)
+        failures = []
+        num_src_per_batch = []
+        for batch in data_loader:
+            src_shape = batch["input_ids"].shape
+            bs = src_shape[0]
+            assert bs % required_batch_size_multiple == 0 or bs < required_batch_size_multiple
+            num_src_tokens = np.product(batch["input_ids"].shape)
+            num_src_per_batch.append(num_src_tokens)
+            if num_src_tokens > (max_tokens * 1.1):
+                failures.append(num_src_tokens)
+        assert num_src_per_batch[0] == max(num_src_per_batch)
+        if failures:
+            raise AssertionError(f"too many tokens in {len(failures)} batches")
+
+    def test_sortish_sampler_reduces_padding(self):
+        ds, _, tokenizer = self._get_dataset(max_len=512)
+        bs = 2
+        sortish_sampler = ds.make_sortish_sampler(bs, shuffle=False)
+
+        naive_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2)
+        sortish_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2, sampler=sortish_sampler)
+
+        pad = tokenizer.pad_token_id
+
+        def count_pad_tokens(data_loader, k="input_ids"):
+            return [batch[k].eq(pad).sum().item() for batch in data_loader]
+
+        assert sum(count_pad_tokens(sortish_dl, k="labels")) < sum(count_pad_tokens(naive_dl, k="labels"))
+        assert sum(count_pad_tokens(sortish_dl)) < sum(count_pad_tokens(naive_dl))
+        assert len(sortish_dl) == len(naive_dl)
+
+    def _get_dataset(self, n_obs=1000, max_len=128):
+        if os.getenv("USE_REAL_DATA", False):
+            data_dir = "examples/seq2seq/wmt_en_ro"
+            max_tokens = max_len * 2 * 64
+            if not Path(data_dir).joinpath("train.len").exists():
+                save_len_file(MARIAN_TINY, data_dir)
+        else:
+            data_dir = "examples/seq2seq/test_data/wmt_en_ro"
+            max_tokens = max_len * 4
+            save_len_file(MARIAN_TINY, data_dir)
+
+        tokenizer = AutoTokenizer.from_pretrained(MARIAN_TINY)
+        ds = Seq2SeqDataset(
+            tokenizer,
+            data_dir=data_dir,
+            type_path="train",
+            max_source_length=max_len,
+            max_target_length=max_len,
+            n_obs=n_obs,
+        )
+        return ds, max_tokens, tokenizer
+
+    def test_distributed_sortish_sampler_splits_indices_between_procs(self):
+        ds, max_tokens, tokenizer = self._get_dataset()
+        ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
+        ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
+        assert ids1.intersection(ids2) == set()
+
+    @parameterized.expand(
+        [
+            MBART_TINY,
+            MARIAN_TINY,
+            T5_TINY,
+            BART_TINY,
+            PEGASUS_XSUM,
+        ],
+    )
+    def test_dataset_kwargs(self, tok_name):
+        tokenizer = AutoTokenizer.from_pretrained(tok_name)
+        if tok_name == MBART_TINY:
+            train_dataset = Seq2SeqDataset(
+                tokenizer,
+                data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
+                type_path="train",
+                max_source_length=4,
+                max_target_length=8,
+                src_lang="EN",
+                tgt_lang="FR",
+            )
+            kwargs = train_dataset.dataset_kwargs
+            assert "src_lang" in kwargs and "tgt_lang" in kwargs
+        else:
+            train_dataset = Seq2SeqDataset(
+                tokenizer,
+                data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
+                type_path="train",
+                max_source_length=4,
+                max_target_length=8,
+            )
+            kwargs = train_dataset.dataset_kwargs
+            assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs
+            assert len(kwargs) == 1 if tok_name == BART_TINY else len(kwargs) == 0
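The sortish-sampler assertions above rest on a simple observation: batching sequences of similar lengths wastes fewer pad positions than batching in arrival order. A toy illustration with invented lengths:

    # Why a "sortish" sampler reduces padding: batches of similar-length
    # sequences waste fewer pad positions. Toy numbers, for illustration only.
    lengths = [3, 17, 4, 18, 2, 16]

    def padding_cost(batches):
        # every sequence in a batch is padded up to the batch max
        return sum(max(b) * len(b) - sum(b) for b in batches)

    naive = [lengths[i : i + 2] for i in range(0, len(lengths), 2)]    # [[3, 17], [4, 18], [2, 16]]
    sorted_ = sorted(lengths)
    sortish = [sorted_[i : i + 2] for i in range(0, len(sorted_), 2)]  # [[2, 3], [4, 16], [17, 18]]

    print(padding_cost(naive), padding_cost(sortish))  # 42 14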
examples/seq2seq/test_finetune_trainer.py
 import os
 import sys
-import tempfile
 from unittest.mock import patch
 
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import set_seed
@@ -15,72 +14,71 @@ set_seed(42)
 MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
 
-def test_finetune_trainer():
-    output_dir = run_trainer(1, "12", MBART_TINY, 1)
-    logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
-    eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
-    first_step_stats = eval_metrics[0]
-    assert "eval_bleu" in first_step_stats
-
-
-@slow
-def test_finetune_trainer_slow():
-    # There is a missing call to __init__process_group somewhere
-    output_dir = run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)
-
-    # Check metrics
-    logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
-    eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
-    first_step_stats = eval_metrics[0]
-    last_step_stats = eval_metrics[-1]
-
-    assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"]  # model learned nothing
-    assert isinstance(last_step_stats["eval_bleu"], float)
-
-    # test if do_predict saves generations and metrics
-    contents = os.listdir(output_dir)
-    contents = {os.path.basename(p) for p in contents}
-    assert "test_generations.txt" in contents
-    assert "test_results.json" in contents
-
-
-def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
-    data_dir = "examples/seq2seq/test_data/wmt_en_ro"
-    output_dir = tempfile.mkdtemp(prefix="test_output")
-    argv = f"""
-        --model_name_or_path {model_name}
-        --data_dir {data_dir}
-        --output_dir {output_dir}
-        --overwrite_output_dir
-        --n_train 8
-        --n_val 8
-        --max_source_length {max_len}
-        --max_target_length {max_len}
-        --val_max_target_length {max_len}
-        --do_train
-        --do_eval
-        --do_predict
-        --num_train_epochs {str(num_train_epochs)}
-        --per_device_train_batch_size 4
-        --per_device_eval_batch_size 4
-        --learning_rate 3e-4
-        --warmup_steps 8
-        --evaluate_during_training
-        --predict_with_generate
-        --logging_steps 0
-        --save_steps {str(eval_steps)}
-        --eval_steps {str(eval_steps)}
-        --sortish_sampler
-        --label_smoothing 0.1
-        --adafactor
-        --task translation
-        --tgt_lang ro_RO
-        --src_lang en_XX
-    """.split()
-    # --eval_beams  2
-    testargs = ["finetune_trainer.py"] + argv
-    with patch.object(sys, "argv", testargs):
-        main()
-
-    return output_dir
+class TestFinetuneTrainer(TestCasePlus):
+    def test_finetune_trainer(self):
+        output_dir = self.run_trainer(1, "12", MBART_TINY, 1)
+        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
+        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
+        first_step_stats = eval_metrics[0]
+        assert "eval_bleu" in first_step_stats
+
+    @slow
+    def test_finetune_trainer_slow(self):
+        # There is a missing call to __init__process_group somewhere
+        output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)
+
+        # Check metrics
+        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
+        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
+        first_step_stats = eval_metrics[0]
+        last_step_stats = eval_metrics[-1]
+
+        assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"]  # model learned nothing
+        assert isinstance(last_step_stats["eval_bleu"], float)
+
+        # test if do_predict saves generations and metrics
+        contents = os.listdir(output_dir)
+        contents = {os.path.basename(p) for p in contents}
+        assert "test_generations.txt" in contents
+        assert "test_results.json" in contents
+
+    def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
+        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
+        output_dir = self.get_auto_remove_tmp_dir()
+        argv = f"""
+            --model_name_or_path {model_name}
+            --data_dir {data_dir}
+            --output_dir {output_dir}
+            --overwrite_output_dir
+            --n_train 8
+            --n_val 8
+            --max_source_length {max_len}
+            --max_target_length {max_len}
+            --val_max_target_length {max_len}
+            --do_train
+            --do_eval
+            --do_predict
+            --num_train_epochs {str(num_train_epochs)}
+            --per_device_train_batch_size 4
+            --per_device_eval_batch_size 4
+            --learning_rate 3e-4
+            --warmup_steps 8
+            --evaluate_during_training
+            --predict_with_generate
+            --logging_steps 0
+            --save_steps {str(eval_steps)}
+            --eval_steps {str(eval_steps)}
+            --sortish_sampler
+            --label_smoothing 0.1
+            --adafactor
+            --task translation
+            --tgt_lang ro_RO
+            --src_lang en_XX
+        """.split()
+        # --eval_beams  2
+        testargs = ["finetune_trainer.py"] + argv
+        with patch.object(sys, "argv", testargs):
+            main()
+
+        return output_dir
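Old and new versions alike drive finetune_trainer.py by building an argv list and patching sys.argv around the call to main(). A self-contained sketch of that technique, with a hypothetical parser standing in for the real script:

    # Sketch of testing an argparse-based main() by patching sys.argv.
    # The parser and flags below are hypothetical stand-ins for
    # finetune_trainer.py; only the patching pattern is the point.
    import argparse
    import sys
    from unittest.mock import patch


    def main():
        parser = argparse.ArgumentParser()
        parser.add_argument("--output_dir", required=True)
        parser.add_argument("--num_train_epochs", type=int, default=1)
        return parser.parse_args()  # reads sys.argv[1:], which the test controls


    def test_main_sees_patched_argv(tmp_path):
        testargs = ["prog.py", f"--output_dir={tmp_path}", "--num_train_epochs=3"]
        with patch.object(sys, "argv", testargs):
            args = main()
        assert args.num_train_epochs == 3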
examples/seq2seq/test_seq2seq_examples.py
This diff is collapsed (+259 −273; not shown here).