chenpangpang / transformers · Commits · 99cb924b

Unverified commit 99cb924b, authored Oct 04, 2020 by Suraj Patil, committed via GitHub on Oct 04, 2020
[s2s] add config params like Dropout in Seq2SeqTrainingArguments (#7532)
parent 9bdce3a4

Showing 4 changed files with 35 additions and 18 deletions (+35 -18)
examples/seq2seq/finetune_trainer.py       +17  -0
examples/seq2seq/seq2seq_trainer.py        +11  -13
examples/seq2seq/test_finetune_trainer.py  +1   -1
examples/seq2seq/utils.py                  +6   -4
examples/seq2seq/finetune_trainer.py

@@ -53,6 +53,16 @@ class Seq2SeqTrainingArguments(TrainingArguments):
         default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
     )
     adafactor: bool = field(default=False, metadata={"help": "whether to use adafactor"})
+    encoder_layerdrop: Optional[float] = field(
+        default=None, metadata={"help": "Encoder layer dropout probability. Goes into model.config."}
+    )
+    decoder_layerdrop: Optional[float] = field(
+        default=None, metadata={"help": "Decoder layer dropout probability. Goes into model.config."}
+    )
+    dropout: Optional[float] = field(default=None, metadata={"help": "Dropout probability. Goes into model.config."})
+    attention_dropout: Optional[float] = field(
+        default=None, metadata={"help": "Attention dropout probability. Goes into model.config."}
+    )


 @dataclass
@@ -179,6 +189,13 @@ def main():
         model_args.config_name if model_args.config_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
     )
+
+    extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
+    for p in extra_model_params:
+        if getattr(training_args, p, None):
+            assert hasattr(config, p), f"({config.__class__.__name__}) doesn't have a `{p}` attribute"
+            setattr(config, p, getattr(training_args, p))
+
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
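The four new arguments are plain dataclass fields, so they can also be exercised programmatically. The snippet below is an illustrative sketch, not part of the diff: it assumes it is run from examples/seq2seq/ so the patched Seq2SeqTrainingArguments is importable, and facebook/bart-base plus the dropout values are placeholders for any config that exposes these attributes. It mirrors the propagation loop added to main().

# Illustrative sketch only: how values set on Seq2SeqTrainingArguments are meant
# to reach the model config via the loop added in main() above.
from transformers import AutoConfig

from finetune_trainer import Seq2SeqTrainingArguments  # assumes cwd is examples/seq2seq/

training_args = Seq2SeqTrainingArguments(
    output_dir="/tmp/s2s",    # required by the base TrainingArguments
    dropout=0.1,              # new field, copied onto model.config
    attention_dropout=0.1,    # new field
    encoder_layerdrop=0.05,   # new field
    decoder_layerdrop=0.05,   # new field
)

config = AutoConfig.from_pretrained("facebook/bart-base")  # placeholder model name

# Same logic as the diff: only touch config attributes the user actually set.
for p in ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout"):
    if getattr(training_args, p, None):
        assert hasattr(config, p), f"({config.__class__.__name__}) doesn't have a `{p}` attribute"
        setattr(config, p, getattr(training_args, p))

print(config.dropout, config.attention_dropout)  # 0.1 0.1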
examples/seq2seq/seq2seq_trainer.py

@@ -6,6 +6,7 @@ from torch import nn
 from torch.utils.data import DistributedSampler, RandomSampler

 from transformers import Trainer
+from transformers.configuration_fsmt import FSMTConfig
 from transformers.file_utils import is_torch_tpu_available
 from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup
 from transformers.trainer import get_tpu_sampler
@@ -26,8 +27,7 @@ class Seq2SeqTrainer(Trainer):
         self.config = config
         self.data_args = data_args
         self.max_gen_length = data_args.val_max_target_length
-        self.pad_token_id = self.config.pad_token_id
-        self.vocab_size = self.config.vocab_size
+        self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size

     def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
@@ -87,18 +87,18 @@ class Seq2SeqTrainer(Trainer):
         labels = inputs.pop("labels")
         outputs = model(**inputs, use_cache=False)
         logits = outputs[0]
-        return self._compute_loss(logits, labels, ignore_index=self.pad_token_id)
+        return self._compute_loss(logits, labels)

-    def _compute_loss(self, logits, labels, ignore_index):
+    def _compute_loss(self, logits, labels):
         if self.args.label_smoothing == 0:
             # Same behavior as modeling_bart.py
-            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
+            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
             assert logits.shape[-1] == self.vocab_size
             loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
         else:
             lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
             loss, nll_loss = label_smoothed_nll_loss(
-                lprobs, labels, self.args.label_smoothing, ignore_index=ignore_index
+                lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id
             )
         return loss
@@ -137,14 +137,12 @@ class Seq2SeqTrainer(Trainer):
                 max_length=self.max_gen_length,
             )
             # in case the batch is shorter than max length, the output should be padded
-            generated_tokens = self._pad_tensors_to_max_len(
-                generated_tokens, self.max_gen_length, self.pad_token_id
-            )
+            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, self.max_gen_length)

         labels_out = inputs.get("labels")
         # Call forward again to get loss # TODO: avoidable?
         outputs = model(**inputs, use_cache=False)
-        loss = self._compute_loss(outputs[1], labels_out, self.pad_token_id)
+        loss = self._compute_loss(outputs[1], labels_out)
         loss = loss.mean().item()
         if self.args.prediction_loss_only:
             return (loss, None, None)
@@ -152,11 +150,11 @@ class Seq2SeqTrainer(Trainer):
         logits = generated_tokens if self.args.predict_with_generate else outputs[1]

         labels_out = labels_out.detach()
-        labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length, self.pad_token_id)
+        labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length)
         return (loss, logits.detach(), labels)

-    def _pad_tensors_to_max_len(self, tensor, max_length, pad_token_id):
-        padded_tensor = pad_token_id * torch.ones(
+    def _pad_tensors_to_max_len(self, tensor, max_length):
+        padded_tensor = self.config.pad_token_id * torch.ones(
             (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
         )
         padded_tensor[:, : tensor.shape[-1]] = tensor
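The refactor in this file removes the pad_token_id plumbing: _compute_loss and _pad_tensors_to_max_len now read the pad token from the config the trainer already holds, and the vocabulary size comes from tgt_vocab_size for FSMT models, which keep separate source and target vocabularies. Below is a stand-alone illustrative sketch of that loss computation, not code from the repository; BartConfig and the random tensors are placeholders.

# Stand-alone sketch (illustrative): pad token id and vocab size are taken from the
# shared model config instead of being passed around as extra arguments.
import torch
from transformers import BartConfig
from transformers.configuration_fsmt import FSMTConfig  # same import the trainer adds

config = BartConfig()  # placeholder; any seq2seq config with pad_token_id works

# FSMT keeps separate source/target vocabularies, hence the special case in __init__.
vocab_size = config.tgt_vocab_size if isinstance(config, FSMTConfig) else config.vocab_size

logits = torch.randn(2, 5, vocab_size)        # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 5))
labels[:, -2:] = config.pad_token_id          # padded positions should not contribute

# Same as the label_smoothing == 0 branch of _compute_loss after the change.
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=config.pad_token_id)
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
print(loss.item())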
examples/seq2seq/test_finetune_trainer.py

@@ -26,7 +26,7 @@ def test_finetune_trainer():
 def test_finetune_trainer_slow():
     # TODO(SS): This will fail on devices with more than 1 GPU.
     # There is a missing call to __init__process_group somewhere
-    output_dir = run_trainer(eval_steps=2, max_len="32", model_name=MARIAN_MODEL, num_train_epochs=3)
+    output_dir = run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)

     # Check metrics
     logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
examples/seq2seq/utils.py

@@ -269,7 +269,11 @@ class Seq2SeqDataCollator:
         ), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
         self.data_args = data_args
         self.tpu_num_cores = tpu_num_cores
-        self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
+        self.dataset_kwargs = {"add_prefix_space": isinstance(tokenizer, BartTokenizer)}
+        if data_args.src_lang is not None:
+            self.dataset_kwargs["src_lang"] = data_args.src_lang
+        if data_args.tgt_lang is not None:
+            self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang

     def __call__(self, batch) -> Dict[str, torch.Tensor]:
         if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
@@ -310,14 +314,12 @@ class Seq2SeqDataCollator:
     def _encode(self, batch) -> Dict[str, torch.Tensor]:
         batch_encoding = self.tokenizer.prepare_seq2seq_batch(
             [x["src_texts"] for x in batch],
-            src_lang=self.data_args.src_lang,
             tgt_texts=[x["tgt_texts"] for x in batch],
-            tgt_lang=self.data_args.tgt_lang,
             max_length=self.data_args.max_source_length,
             max_target_length=self.data_args.max_target_length,
             padding="max_length" if self.tpu_num_cores is not None else "longest",  # TPU hack
             return_tensors="pt",
-            add_prefix_space=self.add_prefix_space,
+            **self.dataset_kwargs,
         )
         return batch_encoding.data
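The collator change follows a small pattern worth spelling out: tokenizer-specific options are collected in a dict and unpacked into prepare_seq2seq_batch, so src_lang and tgt_lang are only forwarded when the dataset actually defines them. A rough stand-alone sketch follows, assuming the prepare_seq2seq_batch API available around the time of this commit (it has since been deprecated); facebook/bart-base and the sample sentences are placeholders.

# Illustrative sketch of the dataset_kwargs pattern introduced above.
from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")  # placeholder model
src_lang, tgt_lang = None, None  # e.g. "en_XX" / "ro_RO" for mBART; None for BART/T5

dataset_kwargs = {"add_prefix_space": isinstance(tokenizer, BartTokenizer)}
if src_lang is not None:
    dataset_kwargs["src_lang"] = src_lang
if tgt_lang is not None:
    dataset_kwargs["tgt_lang"] = tgt_lang

batch = tokenizer.prepare_seq2seq_batch(
    ["A short source sentence."],
    tgt_texts=["A short target sentence."],
    max_length=32,
    max_target_length=32,
    padding="longest",
    return_tensors="pt",
    **dataset_kwargs,  # replaces the hard-coded src_lang / tgt_lang / add_prefix_space kwargs
)
print(batch["input_ids"].shape)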