Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
48f23f92
Unverified
Commit
48f23f92
authored
Oct 01, 2020
by
Sam Shleifer
Committed by
GitHub
Oct 01, 2020
Browse files
[s2sTrainer] test + code cleanup (#7467)
parent
097049b8
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
100 additions
and
114 deletions
+100
-114
examples/seq2seq/finetune.py
examples/seq2seq/finetune.py
+2
-18
examples/seq2seq/finetune_trainer.py
examples/seq2seq/finetune_trainer.py
+24
-48
examples/seq2seq/seq2seq_trainer.py
examples/seq2seq/seq2seq_trainer.py
+16
-19
examples/seq2seq/test_finetune_trainer.py
examples/seq2seq/test_finetune_trainer.py
+39
-29
examples/seq2seq/utils.py
examples/seq2seq/utils.py
+19
-0
No files found.
examples/seq2seq/finetune.py
View file @
48f23f92
...
@@ -26,6 +26,7 @@ from utils import (
...
@@ -26,6 +26,7 @@ from utils import (
calculate_bleu
,
calculate_bleu
,
calculate_rouge
,
calculate_rouge
,
flatten_list
,
flatten_list
,
freeze_embeds
,
freeze_params
,
freeze_params
,
get_git_info
,
get_git_info
,
label_smoothed_nll_loss
,
label_smoothed_nll_loss
,
...
@@ -90,7 +91,7 @@ class SummarizationModule(BaseTransformer):
...
@@ -90,7 +91,7 @@ class SummarizationModule(BaseTransformer):
assert
self
.
target_lens
[
"train"
]
<=
self
.
target_lens
[
"val"
],
f
"target_lens:
{
self
.
target_lens
}
"
assert
self
.
target_lens
[
"train"
]
<=
self
.
target_lens
[
"val"
],
f
"target_lens:
{
self
.
target_lens
}
"
assert
self
.
target_lens
[
"train"
]
<=
self
.
target_lens
[
"test"
],
f
"target_lens:
{
self
.
target_lens
}
"
assert
self
.
target_lens
[
"train"
]
<=
self
.
target_lens
[
"test"
],
f
"target_lens:
{
self
.
target_lens
}
"
if
self
.
hparams
.
freeze_embeds
:
if
self
.
hparams
.
freeze_embeds
:
self
.
freeze_embeds
()
freeze_embeds
(
self
.
model
)
if
self
.
hparams
.
freeze_encoder
:
if
self
.
hparams
.
freeze_encoder
:
freeze_params
(
self
.
model
.
get_encoder
())
freeze_params
(
self
.
model
.
get_encoder
())
assert_all_frozen
(
self
.
model
.
get_encoder
())
assert_all_frozen
(
self
.
model
.
get_encoder
())
...
@@ -105,29 +106,12 @@ class SummarizationModule(BaseTransformer):
...
@@ -105,29 +106,12 @@ class SummarizationModule(BaseTransformer):
Seq2SeqDataset
if
hasattr
(
self
.
tokenizer
,
"prepare_seq2seq_batch"
)
else
LegacySeq2SeqDataset
Seq2SeqDataset
if
hasattr
(
self
.
tokenizer
,
"prepare_seq2seq_batch"
)
else
LegacySeq2SeqDataset
)
)
self
.
eval_beams
=
self
.
model
.
config
.
num_beams
if
self
.
hparams
.
eval_beams
is
None
else
self
.
hparams
.
eval_beams
self
.
eval_beams
=
self
.
model
.
config
.
num_beams
if
self
.
hparams
.
eval_beams
is
None
else
self
.
hparams
.
eval_beams
assert
self
.
eval_beams
>=
1
,
f
"got self.eval_beams=
{
self
.
eval_beams
}
. Need an integer > 1"
if
self
.
hparams
.
eval_max_gen_length
is
not
None
:
if
self
.
hparams
.
eval_max_gen_length
is
not
None
:
self
.
eval_max_length
=
self
.
hparams
.
eval_max_gen_length
self
.
eval_max_length
=
self
.
hparams
.
eval_max_gen_length
else
:
else
:
self
.
eval_max_length
=
self
.
model
.
config
.
max_length
self
.
eval_max_length
=
self
.
model
.
config
.
max_length
self
.
val_metric
=
self
.
default_val_metric
if
self
.
hparams
.
val_metric
is
None
else
self
.
hparams
.
val_metric
self
.
val_metric
=
self
.
default_val_metric
if
self
.
hparams
.
val_metric
is
None
else
self
.
hparams
.
val_metric
def
freeze_embeds
(
self
):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
if
self
.
model_type
==
"t5"
:
freeze_params
(
self
.
model
.
shared
)
for
d
in
[
self
.
model
.
encoder
,
self
.
model
.
decoder
]:
freeze_params
(
d
.
embed_tokens
)
elif
self
.
model_type
==
"fsmt"
:
for
d
in
[
self
.
model
.
model
.
encoder
,
self
.
model
.
model
.
decoder
]:
freeze_params
(
d
.
embed_positions
)
freeze_params
(
d
.
embed_tokens
)
else
:
freeze_params
(
self
.
model
.
model
.
shared
)
for
d
in
[
self
.
model
.
model
.
encoder
,
self
.
model
.
model
.
decoder
]:
freeze_params
(
d
.
embed_positions
)
freeze_params
(
d
.
embed_tokens
)
def
forward
(
self
,
input_ids
,
**
kwargs
):
def
forward
(
self
,
input_ids
,
**
kwargs
):
return
self
.
model
(
input_ids
,
**
kwargs
)
return
self
.
model
(
input_ids
,
**
kwargs
)
...
...
examples/seq2seq/finetune_trainer.py
View file @
48f23f92
import
json
import
logging
import
logging
import
os
import
os
import
sys
import
sys
...
@@ -29,10 +28,13 @@ from utils import (
...
@@ -29,10 +28,13 @@ from utils import (
assert_all_frozen
,
assert_all_frozen
,
calculate_bleu
,
calculate_bleu
,
calculate_rouge
,
calculate_rouge
,
freeze_embeds
,
freeze_params
,
freeze_params
,
lmap
,
lmap
,
save_json
,
trim_batch
,
trim_batch
,
use_task_specific_params
,
use_task_specific_params
,
write_txt_file
,
)
)
...
@@ -43,6 +45,7 @@ class Seq2SeqDataCollator:
...
@@ -43,6 +45,7 @@ class Seq2SeqDataCollator:
def
__init__
(
self
,
tokenizer
,
data_args
,
tpu_num_cores
=
None
):
def
__init__
(
self
,
tokenizer
,
data_args
,
tpu_num_cores
=
None
):
self
.
tokenizer
=
tokenizer
self
.
tokenizer
=
tokenizer
self
.
pad_token_id
=
tokenizer
.
pad_token_id
self
.
pad_token_id
=
tokenizer
.
pad_token_id
assert
self
.
pad_token_id
is
not
None
,
"self.pad_token_id must be defined"
self
.
data_args
=
data_args
self
.
data_args
=
data_args
self
.
tpu_num_cores
=
tpu_num_cores
self
.
tpu_num_cores
=
tpu_num_cores
self
.
add_prefix_space
=
isinstance
(
tokenizer
,
BartTokenizer
)
self
.
add_prefix_space
=
isinstance
(
tokenizer
,
BartTokenizer
)
...
@@ -65,10 +68,8 @@ class Seq2SeqDataCollator:
...
@@ -65,10 +68,8 @@ class Seq2SeqDataCollator:
if
isinstance
(
self
.
tokenizer
,
T5Tokenizer
):
if
isinstance
(
self
.
tokenizer
,
T5Tokenizer
):
decoder_input_ids
=
self
.
_shift_right_t5
(
labels
)
decoder_input_ids
=
self
.
_shift_right_t5
(
labels
)
labels
=
labels
else
:
else
:
decoder_input_ids
=
shift_tokens_right
(
labels
,
self
.
pad_token_id
)
decoder_input_ids
=
shift_tokens_right
(
labels
,
self
.
pad_token_id
)
labels
=
labels
batch
=
{
batch
=
{
"input_ids"
:
input_ids
,
"input_ids"
:
input_ids
,
...
@@ -79,17 +80,10 @@ class Seq2SeqDataCollator:
...
@@ -79,17 +80,10 @@ class Seq2SeqDataCollator:
return
batch
return
batch
def
_shift_right_t5
(
self
,
input_ids
):
def
_shift_right_t5
(
self
,
input_ids
):
decoder_start_token_id
=
self
.
pad_token_id
assert
(
decoder_start_token_id
is
not
None
),
"self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"
# shift inputs to the right
# shift inputs to the right
shifted_input_ids
=
input_ids
.
new_zeros
(
input_ids
.
shape
)
shifted_input_ids
=
input_ids
.
new_zeros
(
input_ids
.
shape
)
shifted_input_ids
[...,
1
:]
=
input_ids
[...,
:
-
1
].
clone
()
shifted_input_ids
[...,
1
:]
=
input_ids
[...,
:
-
1
].
clone
()
shifted_input_ids
[...,
0
]
=
decoder_start_token_id
shifted_input_ids
[...,
0
]
=
self
.
pad_token_id
return
shifted_input_ids
return
shifted_input_ids
def
_encode
(
self
,
batch
)
->
Dict
[
str
,
torch
.
Tensor
]:
def
_encode
(
self
,
batch
)
->
Dict
[
str
,
torch
.
Tensor
]:
...
@@ -267,17 +261,15 @@ def main():
...
@@ -267,17 +261,15 @@ def main():
use_task_specific_params
(
model
,
data_args
.
task
)
use_task_specific_params
(
model
,
data_args
.
task
)
# set num_beams for evaluation
# set num_beams for evaluation
if
data_args
.
eval_beams
is
not
None
:
if
data_args
.
eval_beams
is
None
:
model
.
config
.
num_beams
=
data_args
.
eval_beams
data_args
.
eval_beams
=
model
.
config
.
num_beams
assert
model
.
config
.
num_beams
>=
1
,
f
"got eval_beams=
{
model
.
config
.
num_beams
}
. Need an integer >= 1"
# set max length for generation
model
.
config
.
max_generate_length
=
data_args
.
val_max_target_length
# set decoder_start_token_id for MBart
# set decoder_start_token_id for MBart
if
model
.
config
.
decoder_start_token_id
is
None
and
isinstance
(
tokenizer
,
MBartTokenizer
):
if
model
.
config
.
decoder_start_token_id
is
None
and
isinstance
(
tokenizer
,
MBartTokenizer
):
decoder_start_token_id
=
tokenizer
.
lang_code_to_id
[
data_args
.
tgt_lang
]
assert
(
model
.
config
.
decoder_start_token_id
=
decoder_start_token_id
data_args
.
tgt_lang
is
not
None
and
data_args
.
src_lang
is
not
None
),
"mBart requires --tgt_lang and --src_lang"
model
.
config
.
decoder_start_token_id
=
tokenizer
.
lang_code_to_id
[
data_args
.
tgt_lang
]
def
build_compute_metrics_fn
(
task_name
:
str
)
->
Callable
[[
EvalPrediction
],
Dict
]:
def
build_compute_metrics_fn
(
task_name
:
str
)
->
Callable
[[
EvalPrediction
],
Dict
]:
def
non_pad_len
(
tokens
:
np
.
ndarray
)
->
int
:
def
non_pad_len
(
tokens
:
np
.
ndarray
)
->
int
:
...
@@ -293,32 +285,20 @@ def main():
...
@@ -293,32 +285,20 @@ def main():
def
summarization_metrics
(
pred
:
EvalPrediction
)
->
Dict
:
def
summarization_metrics
(
pred
:
EvalPrediction
)
->
Dict
:
pred_str
,
label_str
=
decode_pred
(
pred
)
pred_str
,
label_str
=
decode_pred
(
pred
)
rouge
:
Dict
=
calculate_rouge
(
pred_str
,
label_str
)
rouge
:
Dict
=
calculate_rouge
(
pred_str
,
label_str
)
summ_len
=
np
.
mean
(
lmap
(
non_pad_len
,
pred
.
predictions
))
summ_len
=
np
.
round
(
np
.
mean
(
lmap
(
non_pad_len
,
pred
.
predictions
))
,
1
)
rouge
.
update
({
"gen_len"
:
summ_len
})
rouge
.
update
({
"gen_len"
:
summ_len
})
return
rouge
return
rouge
def
translation_metrics
(
pred
:
EvalPrediction
)
->
Dict
:
def
translation_metrics
(
pred
:
EvalPrediction
)
->
Dict
:
pred_str
,
label_str
=
decode_pred
(
pred
)
pred_str
,
label_str
=
decode_pred
(
pred
)
bleu
:
Dict
=
calculate_bleu
(
pred_str
,
label_str
)
bleu
:
Dict
=
calculate_bleu
(
pred_str
,
label_str
)
gen_len
=
np
.
mean
(
lmap
(
non_pad_len
,
pred
.
predictions
))
gen_len
=
np
.
round
(
np
.
mean
(
lmap
(
non_pad_len
,
pred
.
predictions
))
,
1
)
bleu
.
update
({
"gen_len"
:
gen_len
})
bleu
.
update
({
"gen_len"
:
gen_len
})
return
bleu
return
bleu
compute_metrics_fn
=
summarization_metrics
if
"summarization"
in
task_name
else
translation_metrics
compute_metrics_fn
=
summarization_metrics
if
"summarization"
in
task_name
else
translation_metrics
return
compute_metrics_fn
return
compute_metrics_fn
def
freeze_embeds
(
model
:
torch
.
nn
.
Module
):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
try
:
freeze_params
(
model
.
model
.
shared
)
for
d
in
[
model
.
model
.
encoder
,
model
.
model
.
decoder
]:
freeze_params
(
d
.
embed_positions
)
freeze_params
(
d
.
embed_tokens
)
except
AttributeError
:
freeze_params
(
model
.
shared
)
for
d
in
[
model
.
encoder
,
model
.
decoder
]:
freeze_params
(
d
.
embed_tokens
)
if
model_args
.
freeze_embeds
:
if
model_args
.
freeze_embeds
:
freeze_embeds
(
model
)
freeze_embeds
(
model
)
if
model_args
.
freeze_encoder
:
if
model_args
.
freeze_encoder
:
...
@@ -376,6 +356,7 @@ def main():
...
@@ -376,6 +356,7 @@ def main():
eval_dataset
=
eval_dataset
,
eval_dataset
=
eval_dataset
,
data_collator
=
Seq2SeqDataCollator
(
tokenizer
,
data_args
,
training_args
.
tpu_num_cores
),
data_collator
=
Seq2SeqDataCollator
(
tokenizer
,
data_args
,
training_args
.
tpu_num_cores
),
compute_metrics
=
build_compute_metrics_fn
(
data_args
.
task
)
if
training_args
.
predict_with_generate
else
None
,
compute_metrics
=
build_compute_metrics_fn
(
data_args
.
task
)
if
training_args
.
predict_with_generate
else
None
,
data_args
=
data_args
,
)
)
# Training
# Training
...
@@ -396,41 +377,36 @@ def main():
...
@@ -396,41 +377,36 @@ def main():
result
=
trainer
.
evaluate
()
result
=
trainer
.
evaluate
()
output_eval_file
=
os
.
path
.
join
(
training_args
.
output_dir
,
"eval_results.json"
)
if
trainer
.
is_world_process_zero
():
if
trainer
.
is_world_process_zero
():
logger
.
info
(
"***** Eval results *****"
)
logger
.
info
(
"***** Eval results *****"
)
for
key
,
value
in
result
.
items
():
for
key
,
value
in
result
.
items
():
logger
.
info
(
" %s = %s"
,
key
,
value
)
logger
.
info
(
" %s = %s"
,
key
,
value
)
save_json
(
result
,
os
.
path
.
join
(
training_args
.
output_dir
,
"eval_results.json"
))
with
open
(
output_eval_file
,
"w"
)
as
f
:
json
.
dump
(
result
,
f
)
eval_results
.
update
(
result
)
eval_results
.
update
(
result
)
if
training_args
.
do_predict
:
if
training_args
.
do_predict
:
logging
.
info
(
"*** Test ***"
)
logging
.
info
(
"*** Test ***"
)
test_output
=
trainer
.
predict
(
test_dataset
=
test_dataset
)
test_output
=
trainer
.
predict
(
test_dataset
=
test_dataset
)
test_metrics
=
test_output
.
metrics
test_metrics
=
{
k
.
replace
(
"eval"
,
"test"
):
v
for
k
,
v
in
test_output
.
metrics
.
items
()}
test_metrics
=
{
k
.
replace
(
"eval"
,
"test"
):
v
for
k
,
v
in
test_metrics
.
items
()}
output_test_file
=
os
.
path
.
join
(
training_args
.
output_dir
,
"test_results.json"
)
if
trainer
.
is_world_process_zero
():
if
trainer
.
is_world_process_zero
():
logger
.
info
(
"***** Test results *****"
)
logger
.
info
(
"***** Test results *****"
)
for
key
,
value
in
test_metrics
.
items
():
for
key
,
value
in
test_metrics
.
items
():
logger
.
info
(
" %s = %s"
,
key
,
value
)
logger
.
info
(
" %s = %s"
,
key
,
value
)
with
open
(
output_test_file
,
"w"
)
as
f
:
save_json
(
test_metrics
,
os
.
path
.
join
(
training_args
.
output_dir
,
"test_results.json"
))
json
.
dump
(
test_metrics
,
f
)
eval_results
.
update
(
test_metrics
)
if
training_args
.
predict_with_generate
:
if
training_args
.
predict_with_generate
:
test_preds
=
tokenizer
.
batch_decode
(
test_output
.
predictions
,
skip_special_tokens
=
True
)
test_preds
=
tokenizer
.
batch_decode
(
test_output
.
predictions
,
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
True
)
test_preds
=
lmap
(
str
.
strip
,
test_preds
)
test_preds
=
lmap
(
str
.
strip
,
test_preds
)
output_test_pred_file
=
os
.
path
.
join
(
training_args
.
output_dir
,
"test_generations.txt"
)
write_txt_file
(
test_preds
,
os
.
path
.
join
(
training_args
.
output_dir
,
"test_generations.txt"
))
with
open
(
output_test_pred_file
,
"w"
)
as
f
:
f
.
write
(
"
\n
"
.
join
(
test_preds
))
if
trainer
.
is_world_process_zero
():
save_json
(
eval_results
,
"all_results.json"
)
return
eval_results
return
eval_results
...
...
examples/seq2seq/seq2seq_trainer.py
View file @
48f23f92
...
@@ -20,6 +20,12 @@ logger = logging.getLogger(__name__)
...
@@ -20,6 +20,12 @@ logger = logging.getLogger(__name__)
class
Seq2SeqTrainer
(
Trainer
):
class
Seq2SeqTrainer
(
Trainer
):
def
__init__
(
self
,
data_args
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
data_args
=
data_args
self
.
max_gen_length
=
data_args
.
val_max_target_length
self
.
pad_token_id
=
self
.
model
.
config
.
pad_token_id
def
_get_train_sampler
(
self
)
->
Optional
[
torch
.
utils
.
data
.
sampler
.
Sampler
]:
def
_get_train_sampler
(
self
)
->
Optional
[
torch
.
utils
.
data
.
sampler
.
Sampler
]:
if
isinstance
(
self
.
train_dataset
,
torch
.
utils
.
data
.
IterableDataset
):
if
isinstance
(
self
.
train_dataset
,
torch
.
utils
.
data
.
IterableDataset
):
return
None
return
None
...
@@ -41,7 +47,7 @@ class Seq2SeqTrainer(Trainer):
...
@@ -41,7 +47,7 @@ class Seq2SeqTrainer(Trainer):
labels
=
inputs
.
pop
(
"labels"
)
labels
=
inputs
.
pop
(
"labels"
)
outputs
=
model
(
**
inputs
,
use_cache
=
False
)
outputs
=
model
(
**
inputs
,
use_cache
=
False
)
logits
=
outputs
[
0
]
logits
=
outputs
[
0
]
return
self
.
_compute_loss
(
logits
,
labels
,
ignore_index
=
model
.
config
.
pad_token_id
)
return
self
.
_compute_loss
(
logits
,
labels
,
ignore_index
=
self
.
pad_token_id
)
def
_compute_loss
(
self
,
logits
,
labels
,
ignore_index
):
def
_compute_loss
(
self
,
logits
,
labels
,
ignore_index
):
if
self
.
args
.
label_smoothing
==
0
:
if
self
.
args
.
label_smoothing
==
0
:
...
@@ -81,41 +87,32 @@ class Seq2SeqTrainer(Trainer):
...
@@ -81,41 +87,32 @@ class Seq2SeqTrainer(Trainer):
"""
"""
inputs
=
self
.
_prepare_inputs
(
inputs
)
inputs
=
self
.
_prepare_inputs
(
inputs
)
max_length
=
(
model
.
config
.
max_generate_length
if
hasattr
(
model
.
config
,
"max_generate_length"
)
else
model
.
config
.
max_position_embeddings
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
if
self
.
args
.
predict_with_generate
and
not
self
.
args
.
prediction_loss_only
:
if
self
.
args
.
predict_with_generate
and
not
self
.
args
.
prediction_loss_only
:
generated_tokens
=
model
.
generate
(
generated_tokens
=
model
.
generate
(
inputs
[
"input_ids"
],
inputs
[
"input_ids"
],
attention_mask
=
inputs
[
"attention_mask"
],
attention_mask
=
inputs
[
"attention_mask"
],
use_cache
=
True
,
use_cache
=
True
,
num_beams
=
model
.
config
.
num
_beams
,
num_beams
=
self
.
data_args
.
eval
_beams
,
max_length
=
max
_length
,
max_length
=
self
.
max_gen
_length
,
)
)
# in case the batch is shorter than max length, the output should be padded
# in case the batch is shorter than max length, the output should be padded
generated_tokens
=
self
.
_pad_tensors_to_max_len
(
generated_tokens
=
self
.
_pad_tensors_to_max_len
(
generated_tokens
,
max_length
,
model
.
config
.
pad_token_id
generated_tokens
,
self
.
max_gen_length
,
self
.
pad_token_id
)
)
labels_out
=
inputs
.
get
(
"labels"
)
labels_out
=
inputs
.
get
(
"labels"
)
outputs
=
model
(
**
inputs
)
# Call forward again to get loss # TODO: avoidable?
logits
=
outputs
[
1
]
outputs
=
model
(
**
inputs
,
use_cache
=
False
)
loss
=
self
.
_compute_loss
(
logits
,
labels_out
,
model
.
config
.
pad_token_id
)
loss
=
self
.
_compute_loss
(
outputs
[
1
]
,
labels_out
,
self
.
pad_token_id
)
loss
=
loss
.
mean
().
item
()
loss
=
loss
.
mean
().
item
()
if
self
.
args
.
prediction_loss_only
:
if
self
.
args
.
prediction_loss_only
:
logits
=
None
return
(
loss
,
None
,
None
)
else
:
logits
=
generated_tokens
if
self
.
args
.
predict_with_generate
else
logits
if
self
.
args
.
prediction_loss_only
:
logits
=
generated_tokens
if
self
.
args
.
predict_with_generate
else
outputs
[
1
]
return
(
loss
,
None
,
None
)
labels_out
=
labels_out
.
detach
()
labels_out
=
labels_out
.
detach
()
labels
=
self
.
_pad_tensors_to_max_len
(
labels_out
,
max_length
,
model
.
config
.
pad_token_id
)
labels
=
self
.
_pad_tensors_to_max_len
(
labels_out
,
self
.
max_gen_length
,
self
.
pad_token_id
)
return
(
loss
,
logits
.
detach
(),
labels
)
return
(
loss
,
logits
.
detach
(),
labels
)
def
_pad_tensors_to_max_len
(
self
,
tensor
,
max_length
,
pad_token_id
):
def
_pad_tensors_to_max_len
(
self
,
tensor
,
max_length
,
pad_token_id
):
...
...
examples/seq2seq/test_finetune_trainer.py
View file @
48f23f92
...
@@ -3,36 +3,54 @@ import sys
...
@@ -3,36 +3,54 @@ import sys
import
tempfile
import
tempfile
from
unittest.mock
import
patch
from
unittest.mock
import
patch
from
transformers
import
BartForConditionalGeneration
,
MarianMTModel
from
transformers.testing_utils
import
slow
from
transformers.testing_utils
import
slow
from
transformers.trainer_utils
import
set_seed
from
.finetune_trainer
import
main
from
.finetune_trainer
import
main
from
.test_seq2seq_examples
import
MBART_TINY
from
.test_seq2seq_examples
import
MBART_TINY
from
.utils
import
load_json
from
.utils
import
load_json
MODEL_NAME
=
MBART_TINY
set_seed
(
42
)
# TODO(SS): MODEL_NAME = "sshleifer/student_mbart_en_ro_1_1"
MARIAN_MODEL
=
"sshleifer/student_marian_en_ro_6_1"
MARIAN_MODEL
=
"sshleifer/student_marian_en_ro_6_1"
@
slow
def
test_finetune_trainer
():
def
test_model_download
():
output_dir
=
run_trainer
(
1
,
"12"
,
MBART_TINY
,
1
)
"""This warms up the cache so that we can time the next test without including download time, which varies between machines."""
logs
=
load_json
(
os
.
path
.
join
(
output_dir
,
"log_history.json"
))
BartForConditionalGeneration
.
from_pretrained
(
MODEL_NAME
)
eval_metrics
=
[
log
for
log
in
logs
if
"eval_loss"
in
log
.
keys
()]
MarianMTModel
.
from_pretrained
(
MARIAN_MODEL
)
first_step_stats
=
eval_metrics
[
0
]
assert
"eval_bleu"
in
first_step_stats
@
slow
@
slow
def
test_finetune_trainer
():
def
test_finetune_trainer_slow
():
# TODO(SS): This will fail on devices with more than 1 GPU.
# There is a missing call to __init__process_group somewhere
output_dir
=
run_trainer
(
eval_steps
=
2
,
max_len
=
"32"
,
model_name
=
MARIAN_MODEL
,
num_train_epochs
=
3
)
# Check metrics
logs
=
load_json
(
os
.
path
.
join
(
output_dir
,
"log_history.json"
))
eval_metrics
=
[
log
for
log
in
logs
if
"eval_loss"
in
log
.
keys
()]
first_step_stats
=
eval_metrics
[
0
]
last_step_stats
=
eval_metrics
[
-
1
]
assert
first_step_stats
[
"eval_bleu"
]
<
last_step_stats
[
"eval_bleu"
]
# model learned nothing
assert
isinstance
(
last_step_stats
[
"eval_bleu"
],
float
)
# test if do_predict saves generations and metrics
contents
=
os
.
listdir
(
output_dir
)
contents
=
{
os
.
path
.
basename
(
p
)
for
p
in
contents
}
assert
"test_generations.txt"
in
contents
assert
"test_results.json"
in
contents
def
run_trainer
(
eval_steps
:
int
,
max_len
:
str
,
model_name
:
str
,
num_train_epochs
:
int
):
data_dir
=
"examples/seq2seq/test_data/wmt_en_ro"
data_dir
=
"examples/seq2seq/test_data/wmt_en_ro"
output_dir
=
tempfile
.
mkdtemp
(
prefix
=
"marian_output"
)
output_dir
=
tempfile
.
mkdtemp
(
prefix
=
"test_output"
)
max_len
=
"128"
num_train_epochs
=
4
eval_steps
=
2
argv
=
[
argv
=
[
"--model_name_or_path"
,
"--model_name_or_path"
,
MARIAN_MODEL
,
model_name
,
"--data_dir"
,
"--data_dir"
,
data_dir
,
data_dir
,
"--output_dir"
,
"--output_dir"
,
...
@@ -72,25 +90,17 @@ def test_finetune_trainer():
...
@@ -72,25 +90,17 @@ def test_finetune_trainer():
"--sortish_sampler"
,
"--sortish_sampler"
,
"--label_smoothing"
,
"--label_smoothing"
,
"0.1"
,
"0.1"
,
# "--eval_beams",
# "2",
"--task"
,
"--task"
,
"translation"
,
"translation"
,
"--tgt_lang"
,
"ro_RO"
,
"--src_lang"
,
"en_XX"
,
]
]
testargs
=
[
"finetune_trainer.py"
]
+
argv
testargs
=
[
"finetune_trainer.py"
]
+
argv
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
main
()
main
()
# Check metrics
return
output_dir
logs
=
load_json
(
os
.
path
.
join
(
output_dir
,
"log_history.json"
))
eval_metrics
=
[
log
for
log
in
logs
if
"eval_loss"
in
log
.
keys
()]
first_step_stats
=
eval_metrics
[
0
]
last_step_stats
=
eval_metrics
[
-
1
]
assert
first_step_stats
[
"eval_bleu"
]
<
last_step_stats
[
"eval_bleu"
]
# model learned nothing
assert
isinstance
(
last_step_stats
[
"eval_bleu"
],
float
)
# test if do_predict saves generations and metrics
contents
=
os
.
listdir
(
output_dir
)
contents
=
{
os
.
path
.
basename
(
p
)
for
p
in
contents
}
assert
"test_generations.txt"
in
contents
assert
"test_results.json"
in
contents
examples/seq2seq/utils.py
View file @
48f23f92
...
@@ -441,6 +441,25 @@ def freeze_params(model: nn.Module):
...
@@ -441,6 +441,25 @@ def freeze_params(model: nn.Module):
par
.
requires_grad
=
False
par
.
requires_grad
=
False
def
freeze_embeds
(
model
):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
model_type
=
model
.
config
.
model_type
if
model_type
==
"t5"
:
freeze_params
(
model
.
shared
)
for
d
in
[
model
.
encoder
,
model
.
decoder
]:
freeze_params
(
d
.
embed_tokens
)
elif
model_type
==
"fsmt"
:
for
d
in
[
model
.
model
.
encoder
,
model
.
model
.
decoder
]:
freeze_params
(
d
.
embed_positions
)
freeze_params
(
d
.
embed_tokens
)
else
:
freeze_params
(
model
.
model
.
shared
)
for
d
in
[
model
.
model
.
encoder
,
model
.
model
.
decoder
]:
freeze_params
(
d
.
embed_positions
)
freeze_params
(
d
.
embed_tokens
)
def
grad_status
(
model
:
nn
.
Module
)
->
Iterable
:
def
grad_status
(
model
:
nn
.
Module
)
->
Iterable
:
return
(
par
.
requires_grad
for
par
in
model
.
parameters
())
return
(
par
.
requires_grad
for
par
in
model
.
parameters
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment