chenpangpang/transformers · Commit 9f7b2b24 (Unverified)
Authored Oct 17, 2020 by Stas Bekman; committed by GitHub on Oct 17, 2020.
[s2s testing] turn all to unittests, use auto-delete temp dirs (#7859)
Parent: dc552b9b
Showing 6 changed files with 674 additions and 694 deletions (+674 −694).
examples/seq2seq/test_bash_script.py                +162 −162
examples/seq2seq/test_data/wmt_en_ro/train.len        +0   −0
examples/seq2seq/test_data/wmt_en_ro/val.len          +0   −0
examples/seq2seq/test_datasets.py                   +191 −195
examples/seq2seq/test_finetune_trainer.py            +62  −64
examples/seq2seq/test_seq2seq_examples.py           +259 −273
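
The same conversion pattern runs through every test module in this commit: module-level pytest functions that created their own scratch directories with tempfile.mkdtemp() become methods of a transformers.testing_utils.TestCasePlus subclass and request a temp dir from the test harness, which removes it automatically at teardown. A minimal sketch of that pattern (the test body here is illustrative, not taken from the diff):

    from transformers.testing_utils import TestCasePlus, slow


    @slow
    class ExampleConversion(TestCasePlus):
        def test_writes_to_auto_removed_dir(self):
            # Replaces tempfile.mkdtemp(prefix=...); the directory is deleted
            # automatically when the test finishes, so no manual cleanup code.
            output_dir = self.get_auto_remove_tmp_dir()
            with open(f"{output_dir}/dummy.txt", "w") as f:
                f.write("scratch output")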
examples/seq2seq/test_bash_script.py

@@ -3,7 +3,6 @@
 import argparse
 import os
 import sys
-import tempfile
 from pathlib import Path
 from unittest.mock import patch

@@ -16,7 +15,7 @@ from distillation import BartSummarizationDistiller, distill_main
 from finetune import SummarizationModule, main
 from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
 from transformers import BartForConditionalGeneration, MarianMTModel
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from utils import load_json

@@ -24,163 +23,164 @@ MODEL_NAME = MBART_TINY
 MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"

The three module-level tests are wrapped into a single unittest-style class and become methods; their bodies are re-indented one level but otherwise change only where noted below.

-@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
-@slow
-def test_model_download():
-    """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
-    BartForConditionalGeneration.from_pretrained(MODEL_NAME)
-    MarianMTModel.from_pretrained(MARIAN_MODEL)
+@slow
+class TestAll(TestCasePlus):
+    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    def test_model_download(self):
+        """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
+        BartForConditionalGeneration.from_pretrained(MODEL_NAME)
+        MarianMTModel.from_pretrained(MARIAN_MODEL)

-@timeout_decorator.timeout(120)
-@slow
-@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
-def test_train_mbart_cc25_enro_script():
     ...
-    output_dir = tempfile.mkdtemp(prefix="output_mbart")
+    @timeout_decorator.timeout(120)
+    @slow
+    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    def test_train_mbart_cc25_enro_script(self):
         ...
+        output_dir = self.get_auto_remove_tmp_dir()

-@timeout_decorator.timeout(600)
-@slow
-@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
-def test_opus_mt_distill_script():
     ...
-    output_dir = tempfile.mkdtemp(prefix="marian_output")
+    @timeout_decorator.timeout(600)
+    @slow
+    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    def test_opus_mt_distill_script(self):
         ...
+        output_dir = self.get_auto_remove_tmp_dir()

Both script tests keep their existing flow: strip the corresponding bash script (train_mbart_cc25_enro.sh / distil_marian_no_teacher.sh) down to its CLI flags, substitute the env_vars_to_replace placeholders, patch sys.argv, run main() / distill_main(), and then check the saved metrics, the Lightning checkpoint's state_dict, and, when do_predict is on, the test_generations.txt / test_results.txt outputs. The only behavioural change is where the output directory comes from; the one formatting change is that the val-count assertion in test_train_mbart_cc25_enro_script,

    assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1  # +1 accounts for val_sanity_check

is wrapped in parentheses across two lines.
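
Since everything in this file is now behind @slow (and skipped on CPU via the CUDA_AVAILABLE check), running it locally means opting in explicitly. Assuming the repository's usual RUN_SLOW opt-in for tests decorated with @slow (an assumption about the test harness, not something stated in this diff), an invocation would look roughly like:

    RUN_SLOW=1 pytest examples/seq2seq/test_bash_script.py -k test_train_mbart_cc25_enro_script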
examples/seq2seq/test_data/wmt_en_ro/train.len: no preview for this file type.
examples/seq2seq/test_data/wmt_en_ro/val.len: no preview for this file type.
examples/seq2seq/test_datasets.py

 import os
-import tempfile
 from pathlib import Path

 import numpy as np

@@ -7,11 +6,12 @@ import pytest
 from torch.utils.data import DataLoader

 from pack_dataset import pack_data_dir
+from parameterized import parameterized
 from save_len_file import save_len_file
 from test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
 from transformers import AutoTokenizer
 from transformers.modeling_bart import shift_tokens_right
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset

@@ -19,202 +19,198 @@ BERT_BASE_CASED = "bert-base-cased"
 PEGASUS_XSUM = "google/pegasus-xsum"

All module-level tests and the _get_dataset helper move into a new @slow class TestAll(TestCasePlus) and become methods taking self. Apart from the one-level re-indent, the changes are mechanical:

-@pytest.mark.parametrize("tok_name", [MBART_TINY, MARIAN_TINY, T5_TINY, BART_TINY, PEGASUS_XSUM])
-def test_seq2seq_dataset_truncation(tok_name):
+    @parameterized.expand([MBART_TINY, MARIAN_TINY, T5_TINY, BART_TINY, PEGASUS_XSUM])
+    @slow
+    def test_seq2seq_dataset_truncation(self, tok_name):

-@pytest.mark.parametrize("tok", [BART_TINY, BERT_BASE_CASED])
-def test_legacy_dataset_truncation(tok):
+    @parameterized.expand([BART_TINY, BERT_BASE_CASED])
+    def test_legacy_dataset_truncation(self, tok):

-@pytest.mark.parametrize("tok_name", [MBART_TINY, MARIAN_TINY, T5_TINY, BART_TINY, PEGASUS_XSUM])
-def test_dataset_kwargs(tok_name):
+    @parameterized.expand([MBART_TINY, MARIAN_TINY, T5_TINY, BART_TINY, PEGASUS_XSUM])
+    def test_dataset_kwargs(self, tok_name):

test_pack_dataset, test_dynamic_batch_size, test_sortish_sampler_reduces_padding, test_distributed_sortish_sampler_splits_indices_between_procs and _get_dataset gain a self parameter, and calls to _get_dataset become self._get_dataset. Every temp directory is now requested from the test harness:

-    tmp_dir = make_test_data_dir()
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())

-    tmp_dir = Path(make_test_data_dir())
-    ...
-    save_dir = Path(tempfile.mkdtemp(prefix="packed_"))
+        tmp_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
+        ...
+        save_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))

The test bodies (truncation checks, language-code placement for mBART, dataset packing, dynamic batching, sortish-sampler padding, distributed sampler index splits, and dataset_kwargs handling) are otherwise essentially unchanged.
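
The switch from pytest.mark.parametrize to parameterized.expand is not cosmetic: pytest's parametrize marker does not apply to methods of unittest.TestCase-style classes, while the parameterized package generates one plain test method per argument set, so it composes with TestCasePlus. A minimal sketch of the decorator's behaviour (names here are illustrative):

    import unittest

    from parameterized import parameterized


    class ExampleParam(unittest.TestCase):
        @parameterized.expand(["bart", "t5"])
        def test_tokenizer_name(self, name):
            # expands into separate methods (roughly test_tokenizer_name_0_bart,
            # test_tokenizer_name_1_t5), each receiving one parameter value
            self.assertIn(name, ("bart", "t5"))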
examples/seq2seq/test_finetune_trainer.py

 import os
 import sys
-import tempfile
 from unittest.mock import patch

-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import set_seed

@@ -15,72 +14,71 @@ set_seed(42)
 MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"

test_finetune_trainer, test_finetune_trainer_slow and the run_trainer helper become methods of a new TestFinetuneTrainer(TestCasePlus) class:

-def test_finetune_trainer():
-    output_dir = run_trainer(1, "12", MBART_TINY, 1)
+class TestFinetuneTrainer(TestCasePlus):
+    def test_finetune_trainer(self):
+        output_dir = self.run_trainer(1, "12", MBART_TINY, 1)

-@slow
-def test_finetune_trainer_slow():
-    # There is a missing call to __init__process_group somewhere
-    output_dir = run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)
+    @slow
+    def test_finetune_trainer_slow(self):
+        # There is a missing call to __init__process_group somewhere
+        output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)

-def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
-    data_dir = "examples/seq2seq/test_data/wmt_en_ro"
-    output_dir = tempfile.mkdtemp(prefix="test_output")
+    def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
+        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
+        output_dir = self.get_auto_remove_tmp_dir()

The checks on trainer_state.json (eval_bleu present, improving between first and last eval, and a float; test_generations.txt / test_results.json written by do_predict) and run_trainer's argv template and sys.argv patching are unchanged apart from the re-indent; run_trainer still returns output_dir.
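
These tests drive the example scripts through their argparse entry points by temporarily replacing sys.argv rather than spawning subprocesses, which keeps everything in-process and easy to debug. A minimal sketch of the mechanism, with a hypothetical cli_main standing in for the real script's main():

    import sys
    from unittest.mock import patch


    def cli_main():
        # stand-in for an argparse-based entry point; it reads sys.argv[1:]
        print(sys.argv[1:])


    testargs = ["finetune_trainer.py", "--n_train", "8", "--do_train"]
    with patch.object(sys, "argv", testargs):
        cli_main()  # sees the patched argv; the original is restored on exit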
examples/seq2seq/test_seq2seq_examples.py

@@ -3,7 +3,6 @@ import logging
 import os
 import sys
 import tempfile
-import unittest
 from pathlib import Path
 from unittest.mock import patch

@@ -15,11 +14,12 @@ import lightning_base
 from convert_pl_checkpoint_to_hf import convert_pl_to_hf
 from distillation import distill_main
 from finetune import SummarizationModule, main
+from parameterized import parameterized
 from run_eval import generate_summaries_or_translations, run_generate
 from run_eval_search import run_search
 from transformers import AutoConfig, AutoModelForSeq2SeqLM
 from transformers.hf_api import HfApi
-from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
+from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_and_cuda, slow
 from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json

@@ -52,7 +52,7 @@ CHEAP_ARGS = {
     "student_decoder_layers": 1,
     "val_check_interval": 1.0,
     "output_dir": "",
-    "fp16": False,  # TODO: set this to CUDA_AVAILABLE if ci installs apex or start using native amp
+    "fp16": False,  # TODO (SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp
     "no_teacher": False,
     "fp16_opt_level": "O1",
     "gpus": 1 if CUDA_AVAILABLE else 0,

@@ -88,6 +88,7 @@ CHEAP_ARGS = {
     "student_encoder_layers": 1,
     "freeze_encoder": False,
     "auto_scale_batch_size": False,
+    "overwrite_output_dir": False,
 }

@@ -110,15 +111,14 @@ logger.addHandler(stream_handler)
 logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks


-def make_test_data_dir(**kwargs):
-    tmp_dir = Path(tempfile.mkdtemp(**kwargs))
+def make_test_data_dir(tmp_dir):
     for split in ["train", "val", "test"]:
-        _dump_articles((tmp_dir / f"{split}.source"), ARTICLES)
-        _dump_articles((tmp_dir / f"{split}.target"), SUMMARIES)
+        _dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES)
+        _dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES)
     return tmp_dir


-class TestSummarizationDistiller(unittest.TestCase):
+class TestSummarizationDistiller(TestCasePlus):

@@ -143,17 +143,6 @@ The skipped test_multigpu method (decorated with @require_multigpu and @unittest.skip("Broken at the moment")) is removed along with its _test_distiller_cli(updates, check_contents=False) body; require_multigpu accordingly drops out of the testing_utils import above.

@@ -173,12 +162,12 @@
-        examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
-        out_path = tempfile.mktemp()
+        examples = lmap(str.strip, Path(model.hparams.data_dir).joinpath("test.source").open().readlines())
+        out_path = tempfile.mktemp()  # XXX: not being cleaned up
         generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr"))
         self.assertTrue(Path(out_path).exists())
-        out_path_new = tempfile.mkdtemp()
+        out_path_new = self.get_auto_remove_tmp_dir()
         convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
         assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))

@@ -253,8 +242,8 @@ (inside _test_distiller_cli)
         args_d: dict = CHEAP_ARGS.copy()
-        tmp_dir = make_test_data_dir()
-        output_dir = tempfile.mkdtemp(prefix="output_")
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
+        output_dir = self.get_auto_remove_tmp_dir()
         args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
         model = distill_main(argparse.Namespace(**args_d))

@@ -279,256 +268,253 @@
The remaining module-level functions (run_eval_tester, test_run_eval, test_run_eval_slow, test_run_eval_search, test_finetune, test_finetune_extra_model_args, test_finetune_lr_schedulers) move into a new class TestTheRest(TestCasePlus) and become methods taking self, for example:

-def run_eval_tester(model):
-    input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
+class TestTheRest(TestCasePlus):
+    def run_eval_tester(self, model):
+        input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"

-@pytest.mark.parametrize("model", [BART_TINY, MBART_TINY])
-@slow
-def test_run_eval_slow(model):
-    run_eval_tester(model)
+    @parameterized.expand([BART_TINY, MBART_TINY])
+    @slow
+    def test_run_eval_slow(self, model):
+        self.run_eval_tester(model)

Within these methods every tempfile.mkdtemp() becomes self.get_auto_remove_tmp_dir() (or Path(self.get_auto_remove_tmp_dir())), make_test_data_dir() calls become make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()), pytest.mark.parametrize becomes parameterized.expand, and in run_eval_tester the trailing os.remove(Path(output_file_name)) is commented out, since the directory is now removed automatically. The test bodies (the run_eval / run_eval_search argument construction and output checks, the finetune freeze/embedding checks for T5, FSMT and BART, the extra_model_args override checks, and the lr_schedulers --help / invalid-choice / cosine checks) are otherwise unchanged, apart from the final assertion being reformatted as

    assert (
        getattr(model.hparams, "lr_scheduler") == supported_param
    ), f"lr_scheduler={supported_param} shouldn't fail"
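
With its new signature, make_test_data_dir no longer owns the lifetime of the directory it fills: callers hand it a path (normally one from get_auto_remove_tmp_dir) and cleanup happens in the test harness. A self-contained sketch of the same shape, with a trivial stand-in writer instead of the module's _dump_articles/ARTICLES/SUMMARIES:

    import os


    def make_test_data_dir(tmp_dir, articles=("a", "b"), summaries=("x", "y")):
        # mirrors the refactored helper: fill a caller-owned directory with one
        # {split}.source / {split}.target pair per split, then return the same path
        for split in ["train", "val", "test"]:
            with open(os.path.join(tmp_dir, f"{split}.source"), "w") as f:
                f.write("\n".join(articles))
            with open(os.path.join(tmp_dir, f"{split}.target"), "w") as f:
                f.write("\n".join(summaries))
        return tmp_dir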