chenpangpang / transformers

Commit 48f23f92 (unverified signature)
[s2sTrainer] test + code cleanup (#7467)

Authored by Sam Shleifer on Oct 01, 2020; committed via GitHub on Oct 01, 2020.
Parent: 097049b8

Showing 5 changed files with 100 additions and 114 deletions (+100 -114).
examples/seq2seq/finetune.py               +2   -18
examples/seq2seq/finetune_trainer.py       +24  -48
examples/seq2seq/seq2seq_trainer.py        +16  -19
examples/seq2seq/test_finetune_trainer.py  +39  -29
examples/seq2seq/utils.py                  +19  -0
examples/seq2seq/finetune.py

@@ -26,6 +26,7 @@ from utils import (
     calculate_bleu,
     calculate_rouge,
     flatten_list,
+    freeze_embeds,
     freeze_params,
     get_git_info,
     label_smoothed_nll_loss,
...
@@ -90,7 +91,7 @@ class SummarizationModule(BaseTransformer):
         assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
         assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
         if self.hparams.freeze_embeds:
-            self.freeze_embeds()
+            freeze_embeds(self.model)
         if self.hparams.freeze_encoder:
             freeze_params(self.model.get_encoder())
             assert_all_frozen(self.model.get_encoder())
...
@@ -105,29 +106,12 @@ class SummarizationModule(BaseTransformer):
             Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
         )
         self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
+        assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
         if self.hparams.eval_max_gen_length is not None:
             self.eval_max_length = self.hparams.eval_max_gen_length
         else:
             self.eval_max_length = self.model.config.max_length
         self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric

-    def freeze_embeds(self):
-        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
-        if self.model_type == "t5":
-            freeze_params(self.model.shared)
-            for d in [self.model.encoder, self.model.decoder]:
-                freeze_params(d.embed_tokens)
-        elif self.model_type == "fsmt":
-            for d in [self.model.model.encoder, self.model.model.decoder]:
-                freeze_params(d.embed_positions)
-                freeze_params(d.embed_tokens)
-        else:
-            freeze_params(self.model.model.shared)
-            for d in [self.model.model.encoder, self.model.model.decoder]:
-                freeze_params(d.embed_positions)
-                freeze_params(d.embed_tokens)
-
     def forward(self, input_ids, **kwargs):
         return self.model(input_ids, **kwargs)
...
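SummarizationModule no longer carries its own freeze logic; it now calls the shared freeze_embeds helper added to utils.py later in this commit, together with the existing freeze_params and assert_all_frozen utilities. A minimal usage sketch, assuming it is run from examples/seq2seq and using a publicly available BART checkpoint (the checkpoint name is only an example):

    from transformers import BartForConditionalGeneration
    from utils import assert_all_frozen, freeze_embeds, freeze_params

    model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")  # example checkpoint
    freeze_embeds(model)                    # token + positional embeddings stop training
    freeze_params(model.get_encoder())      # freeze the whole encoder, as SummarizationModule does
    assert_all_frozen(model.get_encoder())  # complains if any encoder parameter still requires grad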
examples/seq2seq/finetune_trainer.py

-import json
 import logging
 import os
 import sys
...
@@ -29,10 +28,13 @@ from utils import (
     assert_all_frozen,
     calculate_bleu,
     calculate_rouge,
+    freeze_embeds,
     freeze_params,
     lmap,
+    save_json,
     trim_batch,
     use_task_specific_params,
+    write_txt_file,
 )
...
@@ -43,6 +45,7 @@ class Seq2SeqDataCollator:
     def __init__(self, tokenizer, data_args, tpu_num_cores=None):
         self.tokenizer = tokenizer
         self.pad_token_id = tokenizer.pad_token_id
+        assert self.pad_token_id is not None, "self.pad_token_id must be defined"
         self.data_args = data_args
         self.tpu_num_cores = tpu_num_cores
         self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
...
@@ -65,10 +68,8 @@ class Seq2SeqDataCollator:
         if isinstance(self.tokenizer, T5Tokenizer):
             decoder_input_ids = self._shift_right_t5(labels)
-            labels = labels
         else:
             decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
-            labels = labels

         batch = {
             "input_ids": input_ids,
...
@@ -79,17 +80,10 @@ class Seq2SeqDataCollator:
         return batch

     def _shift_right_t5(self, input_ids):
-        decoder_start_token_id = self.pad_token_id
-        assert (
-            decoder_start_token_id is not None
-        ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"
-
         # shift inputs to the right
         shifted_input_ids = input_ids.new_zeros(input_ids.shape)
         shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
-        shifted_input_ids[..., 0] = decoder_start_token_id
+        shifted_input_ids[..., 0] = self.pad_token_id
         return shifted_input_ids

     def _encode(self, batch) -> Dict[str, torch.Tensor]:
...
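The simplified _shift_right_t5 drops the redundant assert and uses the collator's pad_token_id directly as the decoder start token. A toy illustration of the shift itself (pad_token_id assumed to be 0 here):

    import torch

    input_ids = torch.tensor([[5, 6, 7, 1]])  # fake label ids
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
    shifted_input_ids[..., 0] = 0  # pad_token_id doubles as the decoder start token
    # shifted_input_ids is now tensor([[0, 5, 6, 7]])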
@@ -267,17 +261,15 @@ def main():
     use_task_specific_params(model, data_args.task)

     # set num_beams for evaluation
-    if data_args.eval_beams is not None:
-        model.config.num_beams = data_args.eval_beams
-    assert model.config.num_beams >= 1, f"got eval_beams={model.config.num_beams}. Need an integer >= 1"
-
-    # set max length for generation
-    model.config.max_generate_length = data_args.val_max_target_length
+    if data_args.eval_beams is None:
+        data_args.eval_beams = model.config.num_beams

     # set decoder_start_token_id for MBart
     if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
-        decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
-        model.config.decoder_start_token_id = decoder_start_token_id
+        assert (
+            data_args.tgt_lang is not None and data_args.src_lang is not None
+        ), "mBart requires --tgt_lang and --src_lang"
+        model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]

     def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
         def non_pad_len(tokens: np.ndarray) -> int:
...
@@ -293,32 +285,20 @@ def main():
         def summarization_metrics(pred: EvalPrediction) -> Dict:
             pred_str, label_str = decode_pred(pred)
             rouge: Dict = calculate_rouge(pred_str, label_str)
-            summ_len = np.mean(lmap(non_pad_len, pred.predictions))
+            summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
             rouge.update({"gen_len": summ_len})
             return rouge

         def translation_metrics(pred: EvalPrediction) -> Dict:
             pred_str, label_str = decode_pred(pred)
             bleu: Dict = calculate_bleu(pred_str, label_str)
-            gen_len = np.mean(lmap(non_pad_len, pred.predictions))
+            gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
             bleu.update({"gen_len": gen_len})
             return bleu

         compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
         return compute_metrics_fn

-    def freeze_embeds(model: torch.nn.Module):
-        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
-        try:
-            freeze_params(model.model.shared)
-            for d in [model.model.encoder, model.model.decoder]:
-                freeze_params(d.embed_positions)
-                freeze_params(d.embed_tokens)
-        except AttributeError:
-            freeze_params(model.shared)
-            for d in [model.encoder, model.decoder]:
-                freeze_params(d.embed_tokens)
-
     if model_args.freeze_embeds:
         freeze_embeds(model)
     if model_args.freeze_encoder:
...
@@ -376,6 +356,7 @@ def main():
         eval_dataset=eval_dataset,
         data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
         compute_metrics=build_compute_metrics_fn(data_args.task) if training_args.predict_with_generate else None,
+        data_args=data_args,
     )

     # Training
...
@@ -396,41 +377,36 @@ def main():
         result = trainer.evaluate()
-        output_eval_file = os.path.join(training_args.output_dir, "eval_results.json")

         if trainer.is_world_process_zero():
             logger.info("***** Eval results *****")
             for key, value in result.items():
                 logger.info("  %s = %s", key, value)
-            with open(output_eval_file, "w") as f:
-                json.dump(result, f)
+            save_json(result, os.path.join(training_args.output_dir, "eval_results.json"))
             eval_results.update(result)

     if training_args.do_predict:
         logging.info("*** Test ***")

         test_output = trainer.predict(test_dataset=test_dataset)
-        test_metrics = test_output.metrics
-        test_metrics = {k.replace("eval", "test"): v for k, v in test_metrics.items()}
-        output_test_file = os.path.join(training_args.output_dir, "test_results.json")
+        test_metrics = {k.replace("eval", "test"): v for k, v in test_output.metrics.items()}

         if trainer.is_world_process_zero():
             logger.info("***** Test results *****")
             for key, value in test_metrics.items():
                 logger.info("  %s = %s", key, value)
-            with open(output_test_file, "w") as f:
-                json.dump(test_metrics, f)
+            save_json(test_metrics, os.path.join(training_args.output_dir, "test_results.json"))
+            eval_results.update(test_metrics)

             if training_args.predict_with_generate:
-                test_preds = tokenizer.batch_decode(test_output.predictions, skip_special_tokens=True)
+                test_preds = tokenizer.batch_decode(
+                    test_output.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
+                )
                 test_preds = lmap(str.strip, test_preds)
-                output_test_pred_file = os.path.join(training_args.output_dir, "test_generations.txt")
-                with open(output_test_pred_file, "w") as f:
-                    f.write("\n".join(test_preds))
+                write_txt_file(test_preds, os.path.join(training_args.output_dir, "test_generations.txt"))

+    if trainer.is_world_process_zero():
+        save_json(eval_results, "all_results.json")
     return eval_results
...
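save_json and write_txt_file are imported from utils.py and their bodies are not part of this diff; a minimal sketch consistent with how they are called above (the function names are real, the implementations below are assumptions):

    import json

    def save_json(content, path):
        # assumed implementation: dump a dict of metrics to a JSON file
        with open(path, "w") as f:
            json.dump(content, f, indent=4)

    def write_txt_file(lines, path):
        # assumed implementation: one generated prediction per line
        with open(path, "w") as f:
            f.write("\n".join(lines) + "\n")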
examples/seq2seq/seq2seq_trainer.py

@@ -20,6 +20,12 @@ logger = logging.getLogger(__name__)
 class Seq2SeqTrainer(Trainer):
+    def __init__(self, data_args, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.data_args = data_args
+        self.max_gen_length = data_args.val_max_target_length
+        self.pad_token_id = self.model.config.pad_token_id
+
     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
             return None
...
@@ -41,7 +47,7 @@ class Seq2SeqTrainer(Trainer):
         labels = inputs.pop("labels")
         outputs = model(**inputs, use_cache=False)
         logits = outputs[0]
-        return self._compute_loss(logits, labels, ignore_index=model.config.pad_token_id)
+        return self._compute_loss(logits, labels, ignore_index=self.pad_token_id)

     def _compute_loss(self, logits, labels, ignore_index):
         if self.args.label_smoothing == 0:
...
@@ -81,41 +87,32 @@ class Seq2SeqTrainer(Trainer):
         """
         inputs = self._prepare_inputs(inputs)

-        max_length = (
-            model.config.max_generate_length
-            if hasattr(model.config, "max_generate_length")
-            else model.config.max_position_embeddings
-        )
-
         with torch.no_grad():
             if self.args.predict_with_generate and not self.args.prediction_loss_only:
                 generated_tokens = model.generate(
                     inputs["input_ids"],
                     attention_mask=inputs["attention_mask"],
                     use_cache=True,
-                    num_beams=model.config.num_beams,
-                    max_length=max_length,
+                    num_beams=self.data_args.eval_beams,
+                    max_length=self.max_gen_length,
                 )
                 # in case the batch is shorter than max length, the output should be padded
                 generated_tokens = self._pad_tensors_to_max_len(
-                    generated_tokens, max_length, model.config.pad_token_id
+                    generated_tokens, self.max_gen_length, self.pad_token_id
                 )

             labels_out = inputs.get("labels")
-            outputs = model(**inputs)
-            logits = outputs[1]
-            loss = self._compute_loss(logits, labels_out, model.config.pad_token_id)
+            # Call forward again to get loss # TODO: avoidable?
+            outputs = model(**inputs, use_cache=False)
+            loss = self._compute_loss(outputs[1], labels_out, self.pad_token_id)
             loss = loss.mean().item()
-            if self.args.prediction_loss_only:
-                logits = None
-            else:
-                logits = generated_tokens if self.args.predict_with_generate else logits
-
-        if self.args.prediction_loss_only:
-            return (loss, None, None)
+            if self.args.prediction_loss_only:
+                return (loss, None, None)

+        logits = generated_tokens if self.args.predict_with_generate else outputs[1]
         labels_out = labels_out.detach()
-        labels = self._pad_tensors_to_max_len(labels_out, max_length, model.config.pad_token_id)
+        labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length, self.pad_token_id)
         return (loss, logits.detach(), labels)

     def _pad_tensors_to_max_len(self, tensor, max_length, pad_token_id):
...
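With the new __init__, Seq2SeqTrainer has to receive data_args at construction time, which is exactly what the finetune_trainer.py hunk above adds; data_args is consumed by the subclass and everything else is forwarded to Trainer. Roughly, the call site looks like this (the model, args, and train_dataset keywords are not shown in the diff and are assumed from the standard Trainer signature):

    trainer = Seq2SeqTrainer(
        model=model,                  # assumed, as in the stock Trainer
        args=training_args,           # assumed
        train_dataset=train_dataset,  # assumed
        eval_dataset=eval_dataset,
        data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
        compute_metrics=build_compute_metrics_fn(data_args.task) if training_args.predict_with_generate else None,
        data_args=data_args,          # new: supplies eval_beams and val_max_target_length
    )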
examples/seq2seq/test_finetune_trainer.py

@@ -3,36 +3,54 @@ import sys
 import tempfile
 from unittest.mock import patch

+from transformers import BartForConditionalGeneration, MarianMTModel
 from transformers.testing_utils import slow
-from transformers.trainer_utils import set_seed

 from .finetune_trainer import main
 from .test_seq2seq_examples import MBART_TINY
 from .utils import load_json


-set_seed(42)
+MODEL_NAME = MBART_TINY
+# TODO(SS): MODEL_NAME = "sshleifer/student_mbart_en_ro_1_1"
 MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"


 @slow
-def test_finetune_trainer():
+def test_model_download():
+    """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
+    BartForConditionalGeneration.from_pretrained(MODEL_NAME)
+    MarianMTModel.from_pretrained(MARIAN_MODEL)
+
+
+def test_finetune_trainer():
+    output_dir = run_trainer(1, "12", MBART_TINY, 1)
+    logs = load_json(os.path.join(output_dir, "log_history.json"))
+    eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
+    first_step_stats = eval_metrics[0]
+    assert "eval_bleu" in first_step_stats
+
+
+@slow
+def test_finetune_trainer_slow():
+    # TODO(SS): This will fail on devices with more than 1 GPU.
+    # There is a missing call to __init__process_group somewhere
+    output_dir = run_trainer(eval_steps=2, max_len="32", model_name=MARIAN_MODEL, num_train_epochs=3)
+
+    # Check metrics
+    logs = load_json(os.path.join(output_dir, "log_history.json"))
+    eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
+    first_step_stats = eval_metrics[0]
+    last_step_stats = eval_metrics[-1]
+    assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"]  # model learned nothing
+    assert isinstance(last_step_stats["eval_bleu"], float)
+
+    # test if do_predict saves generations and metrics
+    contents = os.listdir(output_dir)
+    contents = {os.path.basename(p) for p in contents}
+    assert "test_generations.txt" in contents
+    assert "test_results.json" in contents
+
+
+def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
     data_dir = "examples/seq2seq/test_data/wmt_en_ro"
-    output_dir = tempfile.mkdtemp(prefix="marian_output")
-    max_len = "128"
-    num_train_epochs = 4
-    eval_steps = 2
+    output_dir = tempfile.mkdtemp(prefix="test_output")
     argv = [
         "--model_name_or_path",
-        MARIAN_MODEL,
+        model_name,
         "--data_dir",
         data_dir,
         "--output_dir",
...
@@ -72,25 +90,17 @@ def test_finetune_trainer():
         "--sortish_sampler",
         "--label_smoothing",
         "0.1",
-        # "--eval_beams",
-        # "2",
         "--task",
         "translation",
+        "--tgt_lang",
+        "ro_RO",
+        "--src_lang",
+        "en_XX",
     ]
     testargs = ["finetune_trainer.py"] + argv
     with patch.object(sys, "argv", testargs):
         main()
-
-    # Check metrics
-    logs = load_json(os.path.join(output_dir, "log_history.json"))
-    eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
-    first_step_stats = eval_metrics[0]
-    last_step_stats = eval_metrics[-1]
-    assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"]  # model learned nothing
-    assert isinstance(last_step_stats["eval_bleu"], float)
-
-    # test if do_predict saves generations and metrics
-    contents = os.listdir(output_dir)
-    contents = {os.path.basename(p) for p in contents}
-    assert "test_generations.txt" in contents
-    assert "test_results.json" in contents
+    return output_dir
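run_trainer drives the real CLI entry point in-process by patching sys.argv instead of spawning a subprocess. The same pattern in isolation, with cli_main as a hypothetical stand-in for finetune_trainer.main:

    import sys
    from unittest.mock import patch


    def cli_main():
        # stand-in for finetune_trainer.main(), which parses its options from sys.argv
        return sys.argv[1:]


    testargs = ["finetune_trainer.py", "--task", "translation", "--label_smoothing", "0.1"]
    with patch.object(sys, "argv", testargs):
        assert cli_main() == ["--task", "translation", "--label_smoothing", "0.1"]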
examples/seq2seq/utils.py

@@ -441,6 +441,25 @@ def freeze_params(model: nn.Module):
         par.requires_grad = False


+def freeze_embeds(model):
+    """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
+    model_type = model.config.model_type
+
+    if model_type == "t5":
+        freeze_params(model.shared)
+        for d in [model.encoder, model.decoder]:
+            freeze_params(d.embed_tokens)
+    elif model_type == "fsmt":
+        for d in [model.model.encoder, model.model.decoder]:
+            freeze_params(d.embed_positions)
+            freeze_params(d.embed_tokens)
+    else:
+        freeze_params(model.model.shared)
+        for d in [model.model.encoder, model.model.decoder]:
+            freeze_params(d.embed_positions)
+            freeze_params(d.embed_tokens)
+
+
 def grad_status(model: nn.Module) -> Iterable:
     return (par.requires_grad for par in model.parameters())
...
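The new utils.freeze_embeds builds on freeze_params and grad_status, both shown as context above. A toy check of how they interact, assuming it is run from examples/seq2seq:

    import torch.nn as nn

    from utils import freeze_params, grad_status

    toy = nn.Linear(4, 4)
    print(list(grad_status(toy)))  # [True, True]  -> weight and bias are still trainable
    freeze_params(toy)
    print(list(grad_status(toy)))  # [False, False] -> requires_grad is off for every parameter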