Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
72d363d9
Unverified
Commit
72d363d9
authored
Oct 01, 2020
by
Suraj Patil
Committed by
GitHub
Oct 01, 2020
Browse files
[examples/s2s] clean up finetune_trainer (#7509)
parent
bd262158
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
107 additions
and
105 deletions
+107
-105
examples/seq2seq/finetune_trainer.py
examples/seq2seq/finetune_trainer.py
+8
-100
examples/seq2seq/seq2seq_trainer.py
examples/seq2seq/seq2seq_trainer.py
+5
-3
examples/seq2seq/utils.py
examples/seq2seq/utils.py
+94
-2
No files found.
examples/seq2seq/finetune_trainer.py
View file @
72d363d9
...
@@ -2,37 +2,29 @@ import logging
...
@@ -2,37 +2,29 @@ import logging
import
os
import
os
import
sys
import
sys
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Optional
import
numpy
as
np
import
torch
from
seq2seq_trainer
import
Seq2SeqTrainer
from
seq2seq_trainer
import
Seq2SeqTrainer
from
transformers
import
(
from
transformers
import
(
AutoConfig
,
AutoConfig
,
AutoModelForSeq2SeqLM
,
AutoModelForSeq2SeqLM
,
AutoTokenizer
,
AutoTokenizer
,
BartTokenizer
,
EvalPrediction
,
HfArgumentParser
,
HfArgumentParser
,
MBartTokenizer
,
MBartTokenizer
,
T5Tokenizer
,
TrainingArguments
,
TrainingArguments
,
set_seed
,
set_seed
,
)
)
from
transformers.modeling_bart
import
shift_tokens_right
from
transformers.trainer_utils
import
EvaluationStrategy
from
transformers.trainer_utils
import
EvaluationStrategy
from
utils
import
(
from
utils
import
(
LegacySeq2SeqDataset
,
LegacySeq2SeqDataset
,
Seq2SeqDataCollator
,
Seq2SeqDataset
,
Seq2SeqDataset
,
assert_all_frozen
,
assert_all_frozen
,
calculate_bleu
,
build_compute_metrics_fn
,
calculate_rouge
,
freeze_embeds
,
freeze_embeds
,
freeze_params
,
freeze_params
,
lmap
,
lmap
,
save_json
,
save_json
,
trim_batch
,
use_task_specific_params
,
use_task_specific_params
,
write_txt_file
,
write_txt_file
,
)
)
...
@@ -41,66 +33,6 @@ from utils import (
...
@@ -41,66 +33,6 @@ from utils import (
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
class
Seq2SeqDataCollator
:
def
__init__
(
self
,
tokenizer
,
data_args
,
tpu_num_cores
=
None
):
self
.
tokenizer
=
tokenizer
self
.
pad_token_id
=
tokenizer
.
pad_token_id
assert
self
.
pad_token_id
is
not
None
,
"self.pad_token_id must be defined"
self
.
data_args
=
data_args
self
.
tpu_num_cores
=
tpu_num_cores
self
.
add_prefix_space
=
isinstance
(
tokenizer
,
BartTokenizer
)
def
__call__
(
self
,
batch
)
->
Dict
[
str
,
torch
.
Tensor
]:
if
hasattr
(
self
.
tokenizer
,
"prepare_seq2seq_batch"
):
batch
=
self
.
_encode
(
batch
)
input_ids
,
attention_mask
,
labels
=
(
batch
[
"input_ids"
],
batch
[
"attention_mask"
],
batch
[
"labels"
],
)
else
:
input_ids
=
torch
.
stack
([
x
[
"input_ids"
]
for
x
in
batch
])
attention_mask
=
torch
.
stack
([
x
[
"attention_mask"
]
for
x
in
batch
])
labels
=
torch
.
stack
([
x
[
"labels"
]
for
x
in
batch
])
labels
=
trim_batch
(
labels
,
self
.
pad_token_id
)
input_ids
,
attention_mask
=
trim_batch
(
input_ids
,
self
.
pad_token_id
,
attention_mask
=
attention_mask
)
if
isinstance
(
self
.
tokenizer
,
T5Tokenizer
):
decoder_input_ids
=
self
.
_shift_right_t5
(
labels
)
else
:
decoder_input_ids
=
shift_tokens_right
(
labels
,
self
.
pad_token_id
)
batch
=
{
"input_ids"
:
input_ids
,
"attention_mask"
:
attention_mask
,
"decoder_input_ids"
:
decoder_input_ids
,
"labels"
:
labels
,
}
return
batch
def
_shift_right_t5
(
self
,
input_ids
):
# shift inputs to the right
shifted_input_ids
=
input_ids
.
new_zeros
(
input_ids
.
shape
)
shifted_input_ids
[...,
1
:]
=
input_ids
[...,
:
-
1
].
clone
()
shifted_input_ids
[...,
0
]
=
self
.
pad_token_id
return
shifted_input_ids
def
_encode
(
self
,
batch
)
->
Dict
[
str
,
torch
.
Tensor
]:
batch_encoding
=
self
.
tokenizer
.
prepare_seq2seq_batch
(
[
x
[
"src_texts"
]
for
x
in
batch
],
src_lang
=
self
.
data_args
.
src_lang
,
tgt_texts
=
[
x
[
"tgt_texts"
]
for
x
in
batch
],
tgt_lang
=
self
.
data_args
.
tgt_lang
,
max_length
=
self
.
data_args
.
max_source_length
,
max_target_length
=
self
.
data_args
.
max_target_length
,
padding
=
"max_length"
if
self
.
tpu_num_cores
is
not
None
else
"longest"
,
# TPU hack
return_tensors
=
"pt"
,
add_prefix_space
=
self
.
add_prefix_space
,
)
return
batch_encoding
.
data
@
dataclass
@
dataclass
class
Seq2SeqTrainingArguments
(
TrainingArguments
):
class
Seq2SeqTrainingArguments
(
TrainingArguments
):
"""
"""
...
@@ -271,34 +203,6 @@ def main():
...
@@ -271,34 +203,6 @@ def main():
),
"mBart requires --tgt_lang and --src_lang"
),
"mBart requires --tgt_lang and --src_lang"
model
.
config
.
decoder_start_token_id
=
tokenizer
.
lang_code_to_id
[
data_args
.
tgt_lang
]
model
.
config
.
decoder_start_token_id
=
tokenizer
.
lang_code_to_id
[
data_args
.
tgt_lang
]
def
build_compute_metrics_fn
(
task_name
:
str
)
->
Callable
[[
EvalPrediction
],
Dict
]:
def
non_pad_len
(
tokens
:
np
.
ndarray
)
->
int
:
return
np
.
count_nonzero
(
tokens
!=
tokenizer
.
pad_token_id
)
def
decode_pred
(
pred
:
EvalPrediction
)
->
Tuple
[
List
[
str
],
List
[
str
]]:
pred_str
=
tokenizer
.
batch_decode
(
pred
.
predictions
,
skip_special_tokens
=
True
)
label_str
=
tokenizer
.
batch_decode
(
pred
.
label_ids
,
skip_special_tokens
=
True
)
pred_str
=
lmap
(
str
.
strip
,
pred_str
)
label_str
=
lmap
(
str
.
strip
,
label_str
)
return
pred_str
,
label_str
def
summarization_metrics
(
pred
:
EvalPrediction
)
->
Dict
:
pred_str
,
label_str
=
decode_pred
(
pred
)
rouge
:
Dict
=
calculate_rouge
(
pred_str
,
label_str
)
summ_len
=
np
.
round
(
np
.
mean
(
lmap
(
non_pad_len
,
pred
.
predictions
)),
1
)
rouge
.
update
({
"gen_len"
:
summ_len
})
return
rouge
def
translation_metrics
(
pred
:
EvalPrediction
)
->
Dict
:
pred_str
,
label_str
=
decode_pred
(
pred
)
bleu
:
Dict
=
calculate_bleu
(
pred_str
,
label_str
)
gen_len
=
np
.
round
(
np
.
mean
(
lmap
(
non_pad_len
,
pred
.
predictions
)),
1
)
bleu
.
update
({
"gen_len"
:
gen_len
})
return
bleu
compute_metrics_fn
=
summarization_metrics
if
"summarization"
in
task_name
else
translation_metrics
return
compute_metrics_fn
if
model_args
.
freeze_embeds
:
if
model_args
.
freeze_embeds
:
freeze_embeds
(
model
)
freeze_embeds
(
model
)
if
model_args
.
freeze_encoder
:
if
model_args
.
freeze_encoder
:
...
@@ -349,13 +253,17 @@ def main():
...
@@ -349,13 +253,17 @@ def main():
)
)
# Initialize our Trainer
# Initialize our Trainer
compute_metrics_fn
=
(
build_compute_metrics_fn
(
data_args
.
task
,
tokenizer
)
if
training_args
.
predict_with_generate
else
None
)
trainer
=
Seq2SeqTrainer
(
trainer
=
Seq2SeqTrainer
(
model
=
model
,
model
=
model
,
config
=
config
,
args
=
training_args
,
args
=
training_args
,
train_dataset
=
train_dataset
,
train_dataset
=
train_dataset
,
eval_dataset
=
eval_dataset
,
eval_dataset
=
eval_dataset
,
data_collator
=
Seq2SeqDataCollator
(
tokenizer
,
data_args
,
training_args
.
tpu_num_cores
),
data_collator
=
Seq2SeqDataCollator
(
tokenizer
,
data_args
,
training_args
.
tpu_num_cores
),
compute_metrics
=
build_
compute_metrics_fn
(
data_args
.
task
)
if
training_args
.
predict_with_generate
else
None
,
compute_metrics
=
compute_metrics_fn
,
data_args
=
data_args
,
data_args
=
data_args
,
)
)
...
...
examples/seq2seq/seq2seq_trainer.py
View file @
72d363d9
...
@@ -20,11 +20,13 @@ logger = logging.getLogger(__name__)
...
@@ -20,11 +20,13 @@ logger = logging.getLogger(__name__)
class
Seq2SeqTrainer
(
Trainer
):
class
Seq2SeqTrainer
(
Trainer
):
def
__init__
(
self
,
data_args
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
config
,
data_args
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
config
=
config
self
.
data_args
=
data_args
self
.
data_args
=
data_args
self
.
max_gen_length
=
data_args
.
val_max_target_length
self
.
max_gen_length
=
data_args
.
val_max_target_length
self
.
pad_token_id
=
self
.
model
.
config
.
pad_token_id
self
.
pad_token_id
=
self
.
config
.
pad_token_id
self
.
vocab_size
=
self
.
config
.
vocab_size
def
_get_train_sampler
(
self
)
->
Optional
[
torch
.
utils
.
data
.
sampler
.
Sampler
]:
def
_get_train_sampler
(
self
)
->
Optional
[
torch
.
utils
.
data
.
sampler
.
Sampler
]:
if
isinstance
(
self
.
train_dataset
,
torch
.
utils
.
data
.
IterableDataset
):
if
isinstance
(
self
.
train_dataset
,
torch
.
utils
.
data
.
IterableDataset
):
...
@@ -53,7 +55,7 @@ class Seq2SeqTrainer(Trainer):
...
@@ -53,7 +55,7 @@ class Seq2SeqTrainer(Trainer):
if
self
.
args
.
label_smoothing
==
0
:
if
self
.
args
.
label_smoothing
==
0
:
# Same behavior as modeling_bart.py
# Same behavior as modeling_bart.py
loss_fct
=
torch
.
nn
.
CrossEntropyLoss
(
ignore_index
=
ignore_index
)
loss_fct
=
torch
.
nn
.
CrossEntropyLoss
(
ignore_index
=
ignore_index
)
assert
logits
.
shape
[
-
1
]
==
self
.
model
.
config
.
vocab_size
assert
logits
.
shape
[
-
1
]
==
self
.
vocab_size
loss
=
loss_fct
(
logits
.
view
(
-
1
,
logits
.
shape
[
-
1
]),
labels
.
view
(
-
1
))
loss
=
loss_fct
(
logits
.
view
(
-
1
,
logits
.
shape
[
-
1
]),
labels
.
view
(
-
1
))
else
:
else
:
lprobs
=
torch
.
nn
.
functional
.
log_softmax
(
logits
,
dim
=-
1
)
lprobs
=
torch
.
nn
.
functional
.
log_softmax
(
logits
,
dim
=-
1
)
...
...
examples/seq2seq/utils.py
View file @
72d363d9
...
@@ -7,7 +7,7 @@ import pickle
...
@@ -7,7 +7,7 @@ import pickle
import
socket
import
socket
from
logging
import
getLogger
from
logging
import
getLogger
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Callable
,
Dict
,
Iterable
,
List
,
Union
from
typing
import
Callable
,
Dict
,
Iterable
,
List
,
Tuple
,
Union
import
git
import
git
import
numpy
as
np
import
numpy
as
np
...
@@ -19,8 +19,9 @@ from torch import nn
...
@@ -19,8 +19,9 @@ from torch import nn
from
torch.utils.data
import
Dataset
,
Sampler
from
torch.utils.data
import
Dataset
,
Sampler
from
sentence_splitter
import
add_newline_to_end_of_each_sentence
from
sentence_splitter
import
add_newline_to_end_of_each_sentence
from
transformers
import
BartTokenizer
from
transformers
import
BartTokenizer
,
EvalPrediction
,
PreTrainedTokenizer
,
T5Tokenizer
from
transformers.file_utils
import
cached_property
from
transformers.file_utils
import
cached_property
from
transformers.modeling_bart
import
shift_tokens_right
try
:
try
:
...
@@ -62,6 +63,35 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
...
@@ -62,6 +63,35 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
return
{
"bleu"
:
round
(
corpus_bleu
(
output_lns
,
[
refs_lns
],
**
kwargs
).
score
,
4
)}
return
{
"bleu"
:
round
(
corpus_bleu
(
output_lns
,
[
refs_lns
],
**
kwargs
).
score
,
4
)}
def
build_compute_metrics_fn
(
task_name
:
str
,
tokenizer
:
PreTrainedTokenizer
)
->
Callable
[[
EvalPrediction
],
Dict
]:
def
non_pad_len
(
tokens
:
np
.
ndarray
)
->
int
:
return
np
.
count_nonzero
(
tokens
!=
tokenizer
.
pad_token_id
)
def
decode_pred
(
pred
:
EvalPrediction
)
->
Tuple
[
List
[
str
],
List
[
str
]]:
pred_str
=
tokenizer
.
batch_decode
(
pred
.
predictions
,
skip_special_tokens
=
True
)
label_str
=
tokenizer
.
batch_decode
(
pred
.
label_ids
,
skip_special_tokens
=
True
)
pred_str
=
lmap
(
str
.
strip
,
pred_str
)
label_str
=
lmap
(
str
.
strip
,
label_str
)
return
pred_str
,
label_str
def
summarization_metrics
(
pred
:
EvalPrediction
)
->
Dict
:
pred_str
,
label_str
=
decode_pred
(
pred
)
rouge
:
Dict
=
calculate_rouge
(
pred_str
,
label_str
)
summ_len
=
np
.
round
(
np
.
mean
(
lmap
(
non_pad_len
,
pred
.
predictions
)),
1
)
rouge
.
update
({
"gen_len"
:
summ_len
})
return
rouge
def
translation_metrics
(
pred
:
EvalPrediction
)
->
Dict
:
pred_str
,
label_str
=
decode_pred
(
pred
)
bleu
:
Dict
=
calculate_bleu
(
pred_str
,
label_str
)
gen_len
=
np
.
round
(
np
.
mean
(
lmap
(
non_pad_len
,
pred
.
predictions
)),
1
)
bleu
.
update
({
"gen_len"
:
gen_len
})
return
bleu
compute_metrics_fn
=
summarization_metrics
if
"summarization"
in
task_name
else
translation_metrics
return
compute_metrics_fn
def
trim_batch
(
def
trim_batch
(
input_ids
,
input_ids
,
pad_token_id
,
pad_token_id
,
...
@@ -230,6 +260,68 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
...
@@ -230,6 +260,68 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
return
batch_encoding
return
batch_encoding
class
Seq2SeqDataCollator
:
def
__init__
(
self
,
tokenizer
,
data_args
,
tpu_num_cores
=
None
):
self
.
tokenizer
=
tokenizer
self
.
pad_token_id
=
tokenizer
.
pad_token_id
assert
(
self
.
pad_token_id
is
not
None
),
f
"pad_token_id is not defined for (
{
self
.
tokenizer
.
__class__
.
__name__
}
), it must be defined."
self
.
data_args
=
data_args
self
.
tpu_num_cores
=
tpu_num_cores
self
.
add_prefix_space
=
isinstance
(
tokenizer
,
BartTokenizer
)
def
__call__
(
self
,
batch
)
->
Dict
[
str
,
torch
.
Tensor
]:
if
hasattr
(
self
.
tokenizer
,
"prepare_seq2seq_batch"
):
batch
=
self
.
_encode
(
batch
)
input_ids
,
attention_mask
,
labels
=
(
batch
[
"input_ids"
],
batch
[
"attention_mask"
],
batch
[
"labels"
],
)
else
:
input_ids
=
torch
.
stack
([
x
[
"input_ids"
]
for
x
in
batch
])
attention_mask
=
torch
.
stack
([
x
[
"attention_mask"
]
for
x
in
batch
])
labels
=
torch
.
stack
([
x
[
"labels"
]
for
x
in
batch
])
labels
=
trim_batch
(
labels
,
self
.
pad_token_id
)
input_ids
,
attention_mask
=
trim_batch
(
input_ids
,
self
.
pad_token_id
,
attention_mask
=
attention_mask
)
if
isinstance
(
self
.
tokenizer
,
T5Tokenizer
):
decoder_input_ids
=
self
.
_shift_right_t5
(
labels
)
else
:
decoder_input_ids
=
shift_tokens_right
(
labels
,
self
.
pad_token_id
)
batch
=
{
"input_ids"
:
input_ids
,
"attention_mask"
:
attention_mask
,
"decoder_input_ids"
:
decoder_input_ids
,
"labels"
:
labels
,
}
return
batch
def
_shift_right_t5
(
self
,
input_ids
):
# shift inputs to the right
shifted_input_ids
=
input_ids
.
new_zeros
(
input_ids
.
shape
)
shifted_input_ids
[...,
1
:]
=
input_ids
[...,
:
-
1
].
clone
()
shifted_input_ids
[...,
0
]
=
self
.
pad_token_id
return
shifted_input_ids
def
_encode
(
self
,
batch
)
->
Dict
[
str
,
torch
.
Tensor
]:
batch_encoding
=
self
.
tokenizer
.
prepare_seq2seq_batch
(
[
x
[
"src_texts"
]
for
x
in
batch
],
src_lang
=
self
.
data_args
.
src_lang
,
tgt_texts
=
[
x
[
"tgt_texts"
]
for
x
in
batch
],
tgt_lang
=
self
.
data_args
.
tgt_lang
,
max_length
=
self
.
data_args
.
max_source_length
,
max_target_length
=
self
.
data_args
.
max_target_length
,
padding
=
"max_length"
if
self
.
tpu_num_cores
is
not
None
else
"longest"
,
# TPU hack
return_tensors
=
"pt"
,
add_prefix_space
=
self
.
add_prefix_space
,
)
return
batch_encoding
.
data
class
SortishSampler
(
Sampler
):
class
SortishSampler
(
Sampler
):
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment