chenpangpang/transformers, commit e6c1f1ca (unverified)
Authored Dec 22, 2020 by Sylvain Gugger; committed by GitHub on Dec 22, 2020

Revert renaming in finetune_trainer (#9262)

parent ab177588
Showing 7 changed files with 25 additions and 15 deletions (+25 -15)
examples/seq2seq/finetune_trainer.py (+16 -8)
examples/seq2seq/test_finetune_trainer.py (+2 -2)
examples/seq2seq/train_distil_marian_enro.sh (+2 -1)
examples/seq2seq/train_distil_marian_enro_tpu.sh (+2 -1)
examples/seq2seq/train_distilbart_cnn.sh (+1 -1)
examples/seq2seq/train_mbart_cc25_enro.sh (+1 -1)
examples/seq2seq/utils.py (+1 -1)
examples/seq2seq/finetune_trainer.py
@@ -93,19 +93,27 @@ class DataTrainingArguments:
             "than this will be truncated, sequences shorter will be padded."
         },
     )
-    max_length: Optional[int] = field(
+    max_target_length: Optional[int] = field(
         default=128,
         metadata={
             "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
             "than this will be truncated, sequences shorter will be padded."
         },
     )
-    eval_max_length: Optional[int] = field(
+    val_max_target_length: Optional[int] = field(
         default=142,
         metadata={
             "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
             "than this will be truncated, sequences shorter will be padded. "
             "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
             "during ``evaluate`` and ``predict``."
         },
     )
+    test_max_target_length: Optional[int] = field(
+        default=142,
+        metadata={
+            "help": "The maximum total sequence length for test target text after tokenization. Sequences longer "
+            "than this will be truncated, sequences shorter will be padded."
+            " This argument is also used to override the ``max_length`` param of ``model.generate``, which is used during ``evaluate`` and ``predict``"
+        },
+    )
     n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."})
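For context: finetune_trainer.py parses DataTrainingArguments with transformers' HfArgumentParser, so each restored field comes back as a command-line flag of the same name. A minimal standalone sketch of that behavior (not part of this commit; the abridged help strings and example values are placeholders):

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DataTrainingArguments:
    # Field names as restored by this commit; help strings abridged here.
    max_target_length: Optional[int] = field(default=128, metadata={"help": "Max target length (train)."})
    val_max_target_length: Optional[int] = field(default=142, metadata={"help": "Max target length (validation)."})
    test_max_target_length: Optional[int] = field(default=142, metadata={"help": "Max target length (test)."})


parser = HfArgumentParser(DataTrainingArguments)
# Equivalent to running the script with --max_target_length 56:
(data_args,) = parser.parse_args_into_dataclasses(["--max_target_length", "56"])
print(data_args.max_target_length, data_args.val_max_target_length, data_args.test_max_target_length)
# 56 142 142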
@@ -233,7 +241,7 @@ def main():
             type_path="train",
             data_dir=data_args.data_dir,
             n_obs=data_args.n_train,
-            max_target_length=data_args.max_length,
+            max_target_length=data_args.max_target_length,
             max_source_length=data_args.max_source_length,
             prefix=model.config.prefix or "",
         )
@@ -246,7 +254,7 @@ def main():
             type_path="val",
             data_dir=data_args.data_dir,
             n_obs=data_args.n_val,
-            max_target_length=data_args.eval_max_length,
+            max_target_length=data_args.val_max_target_length,
             max_source_length=data_args.max_source_length,
             prefix=model.config.prefix or "",
         )
@@ -259,7 +267,7 @@ def main():
             type_path="test",
             data_dir=data_args.data_dir,
             n_obs=data_args.n_test,
-            max_target_length=data_args.eval_max_length,
+            max_target_length=data_args.test_max_target_length,
             max_source_length=data_args.max_source_length,
             prefix=model.config.prefix or "",
         )
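Taken together, the three hunks above restore one target-length cap per split instead of a single shared eval length. Schematically (a trivial illustration with a hypothetical stand-in for the parsed arguments):

from types import SimpleNamespace

# Hypothetical stand-in for the parsed DataTrainingArguments defaults:
data_args = SimpleNamespace(max_target_length=128, val_max_target_length=142, test_max_target_length=142)

# One cap per split, mirroring the three dataset constructions above:
max_target_length_by_split = {
    "train": data_args.max_target_length,
    "val": data_args.val_max_target_length,
    "test": data_args.test_max_target_length,
}
print(max_target_length_by_split)  # {'train': 128, 'val': 142, 'test': 142}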
@@ -310,7 +318,7 @@ def main():
         logger.info("*** Evaluate ***")

         metrics = trainer.evaluate(
-            metric_key_prefix="val", max_length=data_args.eval_max_length, num_beams=data_args.eval_beams
+            metric_key_prefix="val", max_length=data_args.val_max_target_length, num_beams=data_args.eval_beams
         )
         metrics["val_n_objs"] = data_args.n_val
         metrics["val_loss"] = round(metrics["val_loss"], 4)
@@ -326,7 +334,7 @@ def main():
         test_output = trainer.predict(
             test_dataset=test_dataset,
             metric_key_prefix="test",
-            max_length=data_args.eval_max_length,
+            max_length=data_args.val_max_target_length,
             num_beams=data_args.eval_beams,
         )
         metrics = test_output.metrics
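As the restored help text notes, val_max_target_length also overrides the ``max_length`` param of ``model.generate`` during ``evaluate`` and ``predict``, which is what the two hunks above wire up. A rough self-contained sketch of what that override amounts to (checkpoint chosen only for illustration; any seq2seq model works):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "Helsinki-NLP/opus-mt-en-ro"  # illustrative seq2seq checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

batch = tokenizer(["UN Chief Says There Is No Military Solution in Syria"], return_tensors="pt")
# Inside evaluate/predict the trainer effectively does this, with
# max_length=data_args.val_max_target_length and num_beams=data_args.eval_beams:
generated = model.generate(**batch, max_length=142, num_beams=4)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))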
examples/seq2seq/test_finetune_trainer.py
@@ -137,8 +137,8 @@ class TestFinetuneTrainer(TestCasePlus):
             --n_train 8
             --n_val 8
             --max_source_length {max_len}
-            --max_length {max_len}
-            --eval_max_length {max_len}
+            --max_target_length {max_len}
+            --val_max_target_length {max_len}
             --do_train
             --do_eval
             --do_predict
examples/seq2seq/train_distil_marian_enro.sh
@@ -29,7 +29,8 @@ python finetune_trainer.py \
     --freeze_encoder --freeze_embeds \
     --num_train_epochs=6 \
     --save_steps 3000 --eval_steps 3000 \
-    --max_source_length $MAX_LEN --max_length $MAX_LEN --eval_max_length $MAX_LEN \
+    --max_source_length $MAX_LEN --max_target_length $MAX_LEN \
+    --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
     --do_train --do_eval --do_predict \
     --evaluation_strategy steps \
     --predict_with_generate --logging_first_step \
examples/seq2seq/train_distil_marian_enro_tpu.sh
@@ -30,7 +30,8 @@ python xla_spawn.py --num_cores $TPU_NUM_CORES \
     --num_train_epochs=6 \
     --save_steps 500 --eval_steps 500 \
     --logging_first_step --logging_steps 200 \
-    --max_source_length $MAX_LEN --max_length $MAX_LEN --eval_max_length $MAX_LEN \
+    --max_source_length $MAX_LEN --max_target_length $MAX_LEN \
+    --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
     --do_train --do_eval \
     --evaluation_strategy steps \
     --prediction_loss_only \
examples/seq2seq/train_distilbart_cnn.sh
@@ -32,7 +32,7 @@ python finetune_trainer.py \
     --num_train_epochs=2 \
     --save_steps 3000 --eval_steps 3000 \
     --logging_first_step \
-    --max_length 56 --eval_max_length $MAX_TGT_LEN \
+    --max_target_length 56 --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
     --do_train --do_eval --do_predict \
     --evaluation_strategy steps \
     --predict_with_generate --sortish_sampler \
examples/seq2seq/train_mbart_cc25_enro.sh
@@ -24,7 +24,7 @@ python finetune_trainer.py \
     --src_lang en_XX --tgt_lang ro_RO \
     --freeze_embeds \
     --per_device_train_batch_size=4 --per_device_eval_batch_size=4 \
-    --max_source_length 128 --max_length 128 --eval_max_length 128 \
+    --max_source_length 128 --max_target_length 128 --val_max_target_length 128 --test_max_target_length 128 \
     --sortish_sampler \
     --num_train_epochs 6 \
     --save_steps 25000 --eval_steps 25000 --logging_steps 1000 \
examples/seq2seq/utils.py
@@ -330,7 +330,7 @@ class Seq2SeqDataCollator:
             [x["src_texts"] for x in batch],
             tgt_texts=[x["tgt_texts"] for x in batch],
             max_length=self.data_args.max_source_length,
-            max_target_length=self.data_args.max_length,
+            max_target_length=self.data_args.max_target_length,
             padding="max_length" if self.tpu_num_cores is not None else "longest",  # TPU hack
             return_tensors="pt",
             **self.dataset_kwargs,
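For reference, this hunk sits inside Seq2SeqDataCollator, which hands both length caps to the tokenizer in a single call. A sketch of the call shape under the tokenizer API of this era (prepare_seq2seq_batch existed in transformers at the time of this commit and was deprecated later; checkpoint and values are illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ro")
batch = tokenizer.prepare_seq2seq_batch(
    ["UN Chief Says There Is No Military Solution in Syria"],  # src_texts
    tgt_texts=["Șeful ONU declară că nu există o soluție militară în Siria"],
    max_length=128,           # data_args.max_source_length
    max_target_length=128,    # data_args.max_target_length (name restored by this commit)
    padding="longest",        # "max_length" when running on TPU, per the hunk above
    return_tensors="pt",
)
print(sorted(batch.keys()))  # ['attention_mask', 'input_ids', 'labels']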