chenpangpang/transformers, commit 72d363d9 (unverified)
Authored Oct 01, 2020 by Suraj Patil; committed via GitHub on Oct 01, 2020

[examples/s2s] clean up finetune_trainer (#7509)

parent bd262158
Showing 3 changed files, with 107 additions and 105 deletions.
examples/seq2seq/finetune_trainer.py    +8   -100
examples/seq2seq/seq2seq_trainer.py     +5   -3
examples/seq2seq/utils.py               +94  -2
examples/seq2seq/finetune_trainer.py
@@ -2,37 +2,29 @@ import logging
 import os
 import sys
 from dataclasses import dataclass, field
-from typing import Callable, Dict, List, Optional, Tuple
-
-import numpy as np
-import torch
+from typing import Optional

 from seq2seq_trainer import Seq2SeqTrainer
 from transformers import (
     AutoConfig,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
-    BartTokenizer,
-    EvalPrediction,
     HfArgumentParser,
     MBartTokenizer,
-    T5Tokenizer,
     TrainingArguments,
     set_seed,
 )
-from transformers.modeling_bart import shift_tokens_right
 from transformers.trainer_utils import EvaluationStrategy
 from utils import (
     LegacySeq2SeqDataset,
+    Seq2SeqDataCollator,
     Seq2SeqDataset,
     assert_all_frozen,
-    calculate_bleu,
-    calculate_rouge,
+    build_compute_metrics_fn,
     freeze_embeds,
     freeze_params,
     lmap,
     save_json,
-    trim_batch,
     use_task_specific_params,
     write_txt_file,
 )
@@ -41,66 +33,6 @@ from utils import (

 logger = logging.getLogger(__name__)


-class Seq2SeqDataCollator:
-    def __init__(self, tokenizer, data_args, tpu_num_cores=None):
-        self.tokenizer = tokenizer
-        self.pad_token_id = tokenizer.pad_token_id
-        assert self.pad_token_id is not None, "self.pad_token_id must be defined"
-        self.data_args = data_args
-        self.tpu_num_cores = tpu_num_cores
-        self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
-
-    def __call__(self, batch) -> Dict[str, torch.Tensor]:
-        if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
-            batch = self._encode(batch)
-            input_ids, attention_mask, labels = (
-                batch["input_ids"],
-                batch["attention_mask"],
-                batch["labels"],
-            )
-        else:
-            input_ids = torch.stack([x["input_ids"] for x in batch])
-            attention_mask = torch.stack([x["attention_mask"] for x in batch])
-            labels = torch.stack([x["labels"] for x in batch])
-
-            labels = trim_batch(labels, self.pad_token_id)
-            input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
-
-        if isinstance(self.tokenizer, T5Tokenizer):
-            decoder_input_ids = self._shift_right_t5(labels)
-        else:
-            decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
-
-        batch = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "decoder_input_ids": decoder_input_ids,
-            "labels": labels,
-        }
-        return batch
-
-    def _shift_right_t5(self, input_ids):
-        # shift inputs to the right
-        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
-        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
-        shifted_input_ids[..., 0] = self.pad_token_id
-        return shifted_input_ids
-
-    def _encode(self, batch) -> Dict[str, torch.Tensor]:
-        batch_encoding = self.tokenizer.prepare_seq2seq_batch(
-            [x["src_texts"] for x in batch],
-            src_lang=self.data_args.src_lang,
-            tgt_texts=[x["tgt_texts"] for x in batch],
-            tgt_lang=self.data_args.tgt_lang,
-            max_length=self.data_args.max_source_length,
-            max_target_length=self.data_args.max_target_length,
-            padding="max_length" if self.tpu_num_cores is not None else "longest",  # TPU hack
-            return_tensors="pt",
-            add_prefix_space=self.add_prefix_space,
-        )
-        return batch_encoding.data
-
-
 @dataclass
 class Seq2SeqTrainingArguments(TrainingArguments):
     """
@@ -271,34 +203,6 @@ def main():
         ), "mBart requires --tgt_lang and --src_lang"
         model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]

-    def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
-        def non_pad_len(tokens: np.ndarray) -> int:
-            return np.count_nonzero(tokens != tokenizer.pad_token_id)
-
-        def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
-            pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
-            label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
-            pred_str = lmap(str.strip, pred_str)
-            label_str = lmap(str.strip, label_str)
-            return pred_str, label_str
-
-        def summarization_metrics(pred: EvalPrediction) -> Dict:
-            pred_str, label_str = decode_pred(pred)
-            rouge: Dict = calculate_rouge(pred_str, label_str)
-            summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
-            rouge.update({"gen_len": summ_len})
-            return rouge
-
-        def translation_metrics(pred: EvalPrediction) -> Dict:
-            pred_str, label_str = decode_pred(pred)
-            bleu: Dict = calculate_bleu(pred_str, label_str)
-            gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
-            bleu.update({"gen_len": gen_len})
-            return bleu
-
-        compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
-        return compute_metrics_fn
-
     if model_args.freeze_embeds:
         freeze_embeds(model)
     if model_args.freeze_encoder:
@@ -349,13 +253,17 @@ def main():
     )

     # Initialize our Trainer
+    compute_metrics_fn = (
+        build_compute_metrics_fn(data_args.task, tokenizer) if training_args.predict_with_generate else None
+    )
     trainer = Seq2SeqTrainer(
         model=model,
+        config=config,
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
-        compute_metrics=build_compute_metrics_fn(data_args.task) if training_args.predict_with_generate else None,
+        compute_metrics=compute_metrics_fn,
         data_args=data_args,
     )
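Read together with the import hunk above, the post-change wiring in main() comes down to the sketch below. This is a minimal illustration rather than the full script: the checkpoint name is a placeholder, and data_args, training_args, train_dataset and eval_dataset are assumed to be built earlier in main() exactly as in the unchanged parts of the file.

# Sketch of the trainer wiring after this commit; run from examples/seq2seq/
# so the local imports resolve. data_args, training_args and the datasets
# are assumed to come from the surrounding main(), as in the file itself.
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer

from seq2seq_trainer import Seq2SeqTrainer
from utils import Seq2SeqDataCollator, build_compute_metrics_fn

model_name = "facebook/bart-base"  # placeholder checkpoint
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config)

compute_metrics_fn = (
    build_compute_metrics_fn(data_args.task, tokenizer) if training_args.predict_with_generate else None
)
trainer = Seq2SeqTrainer(
    model=model,
    config=config,  # new in this commit: pad_token_id and vocab_size are read from this config
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
    compute_metrics=compute_metrics_fn,
    data_args=data_args,
)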
examples/seq2seq/seq2seq_trainer.py
@@ -20,11 +20,13 @@ logger = logging.getLogger(__name__)


 class Seq2SeqTrainer(Trainer):
-    def __init__(self, data_args, *args, **kwargs):
+    def __init__(self, config, data_args, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.config = config
         self.data_args = data_args
         self.max_gen_length = data_args.val_max_target_length
-        self.pad_token_id = self.model.config.pad_token_id
+        self.pad_token_id = self.config.pad_token_id
+        self.vocab_size = self.config.vocab_size

     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
@@ -53,7 +55,7 @@ class Seq2SeqTrainer(Trainer):
         if self.args.label_smoothing == 0:
             # Same behavior as modeling_bart.py
             loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
-            assert logits.shape[-1] == self.model.config.vocab_size
+            assert logits.shape[-1] == self.vocab_size
             loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
         else:
             lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
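The assertion change above only alters where vocab_size is read from; the shape contract it enforces is unchanged. A small self-contained illustration of that contract with toy sizes (the sizes and pad id below are invented for the example):

import torch

batch_size, seq_len, vocab_size = 2, 5, 11  # toy sizes
pad_token_id = 1                            # placeholder ignore index

logits = torch.randn(batch_size, seq_len, vocab_size)          # what the model returns
labels = torch.randint(0, vocab_size, (batch_size, seq_len))   # target token ids

assert logits.shape[-1] == vocab_size  # the check Seq2SeqTrainer performs

# Same flattening the trainer does before CrossEntropyLoss:
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
print(loss.item())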
examples/seq2seq/utils.py
@@ -7,7 +7,7 @@ import pickle
 import socket
 from logging import getLogger
 from pathlib import Path
-from typing import Callable, Dict, Iterable, List, Union
+from typing import Callable, Dict, Iterable, List, Tuple, Union

 import git
 import numpy as np
@@ -19,8 +19,9 @@ from torch import nn
 from torch.utils.data import Dataset, Sampler

 from sentence_splitter import add_newline_to_end_of_each_sentence
-from transformers import BartTokenizer
+from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
 from transformers.file_utils import cached_property
+from transformers.modeling_bart import shift_tokens_right


 try:
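shift_tokens_right is imported here because the Seq2SeqDataCollator added further down uses it to build decoder_input_ids for non-T5 models. A toy illustration of the two shifting strategies the collator chooses between, assuming BART-like ids (eos = 2, pad = 1) and a T5-like pad id of 0; shift_right_t5 below simply mirrors the collator's private helper for the sake of the example:

import torch
from transformers.modeling_bart import shift_tokens_right

# Toy label ids: tokens 5, 6, 7, then eos (2), then padding (1).
labels = torch.tensor([[5, 6, 7, 2, 1]])

# BART-style: wrap the final non-pad token (eos) around to position 0.
print(shift_tokens_right(labels, 1))
# expected: tensor([[2, 5, 6, 7, 2]])

# T5-style (what the collator's _shift_right_t5 does): prepend the pad id.
def shift_right_t5(input_ids, pad_token_id=0):
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[..., 1:] = input_ids[..., :-1].clone()
    shifted[..., 0] = pad_token_id
    return shifted

print(shift_right_t5(torch.tensor([[5, 6, 7, 2, 0]])))
# expected: tensor([[0, 5, 6, 7, 2]])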
@@ -62,6 +63,35 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
     return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}


+def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]:
+    def non_pad_len(tokens: np.ndarray) -> int:
+        return np.count_nonzero(tokens != tokenizer.pad_token_id)
+
+    def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
+        pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
+        label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
+        pred_str = lmap(str.strip, pred_str)
+        label_str = lmap(str.strip, label_str)
+        return pred_str, label_str
+
+    def summarization_metrics(pred: EvalPrediction) -> Dict:
+        pred_str, label_str = decode_pred(pred)
+        rouge: Dict = calculate_rouge(pred_str, label_str)
+        summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
+        rouge.update({"gen_len": summ_len})
+        return rouge
+
+    def translation_metrics(pred: EvalPrediction) -> Dict:
+        pred_str, label_str = decode_pred(pred)
+        bleu: Dict = calculate_bleu(pred_str, label_str)
+        gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
+        bleu.update({"gen_len": gen_len})
+        return bleu
+
+    compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
+    return compute_metrics_fn
+
+
 def trim_batch(
     input_ids,
     pad_token_id,
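A hedged usage sketch for the helper added above: the checkpoint name is a placeholder, and the token ids stand in for output that model.generate would normally produce during evaluation.

import numpy as np
from transformers import AutoTokenizer, EvalPrediction

from utils import build_compute_metrics_fn

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")  # placeholder checkpoint

# "translation" selects the BLEU branch; any task name containing "summarization" selects ROUGE.
metrics_fn = build_compute_metrics_fn("translation", tokenizer)

# Toy generated / reference token ids, standing in for model.generate output.
pred_ids = np.array(tokenizer(["a cat sat on the mat"], padding=True)["input_ids"])
label_ids = np.array(tokenizer(["the cat sat on the mat"], padding=True)["input_ids"])

print(metrics_fn(EvalPrediction(predictions=pred_ids, label_ids=label_ids)))
# e.g. a dict with "bleu" and "gen_len" keys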
@@ -230,6 +260,68 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
         return batch_encoding


+class Seq2SeqDataCollator:
+    def __init__(self, tokenizer, data_args, tpu_num_cores=None):
+        self.tokenizer = tokenizer
+        self.pad_token_id = tokenizer.pad_token_id
+        assert (
+            self.pad_token_id is not None
+        ), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
+        self.data_args = data_args
+        self.tpu_num_cores = tpu_num_cores
+        self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
+
+    def __call__(self, batch) -> Dict[str, torch.Tensor]:
+        if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
+            batch = self._encode(batch)
+            input_ids, attention_mask, labels = (
+                batch["input_ids"],
+                batch["attention_mask"],
+                batch["labels"],
+            )
+        else:
+            input_ids = torch.stack([x["input_ids"] for x in batch])
+            attention_mask = torch.stack([x["attention_mask"] for x in batch])
+            labels = torch.stack([x["labels"] for x in batch])
+
+            labels = trim_batch(labels, self.pad_token_id)
+            input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
+
+        if isinstance(self.tokenizer, T5Tokenizer):
+            decoder_input_ids = self._shift_right_t5(labels)
+        else:
+            decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
+
+        batch = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "decoder_input_ids": decoder_input_ids,
+            "labels": labels,
+        }
+        return batch
+
+    def _shift_right_t5(self, input_ids):
+        # shift inputs to the right
+        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+        shifted_input_ids[..., 0] = self.pad_token_id
+        return shifted_input_ids
+
+    def _encode(self, batch) -> Dict[str, torch.Tensor]:
+        batch_encoding = self.tokenizer.prepare_seq2seq_batch(
+            [x["src_texts"] for x in batch],
+            src_lang=self.data_args.src_lang,
+            tgt_texts=[x["tgt_texts"] for x in batch],
+            tgt_lang=self.data_args.tgt_lang,
+            max_length=self.data_args.max_source_length,
+            max_target_length=self.data_args.max_target_length,
+            padding="max_length" if self.tpu_num_cores is not None else "longest",  # TPU hack
+            return_tensors="pt",
+            add_prefix_space=self.add_prefix_space,
+        )
+        return batch_encoding.data
+
+
 class SortishSampler(Sampler):
     "Go through the text data by order of src length with a bit of randomness. From fastai repo."
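A usage sketch for the collator in its new home. The Namespace stands in for the script's real DataTrainingArguments (only the fields the collator reads), the checkpoint name is a placeholder, and with a BART tokenizer the prepare_seq2seq_batch branch is the one taken.

from argparse import Namespace

from transformers import AutoTokenizer

from utils import Seq2SeqDataCollator

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")  # placeholder checkpoint

# Stand-in for the script's data_args; src_lang/tgt_lang stay None outside the mBART case.
data_args = Namespace(src_lang=None, tgt_lang=None, max_source_length=32, max_target_length=32)

collator = Seq2SeqDataCollator(tokenizer, data_args)  # tpu_num_cores=None -> "longest" padding

# Each example carries raw text; the collator tokenizes, trims and builds decoder_input_ids.
batch = collator([
    {"src_texts": "A long source document.", "tgt_texts": "A summary."},
    {"src_texts": "Another document.", "tgt_texts": "Another summary."},
])
print({k: v.shape for k, v in batch.items()})
# tensors for input_ids, attention_mask, decoder_input_ids, labels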