Unverified Commit e6c1f1ca authored by Sylvain Gugger, committed by GitHub

Revert renaming in finetune_trainer (#9262)

parent ab177588
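This revert restores the original data argument names in finetune_trainer.py: --max_target_length, --val_max_target_length, and --test_max_target_length replace the short-lived --max_length / --eval_max_length. A minimal invocation sketch with the restored flags follows; only the flag names come from this commit, while the model, data directory, and output directory are illustrative placeholders (128 and 142 are the dataclass defaults shown in the diff below):

# Sketch only: flag names are from this revert; DATA_DIR, the model, and the
# output directory are hypothetical placeholders, not part of the commit.
python finetune_trainer.py \
    --model_name_or_path t5-small \
    --data_dir $DATA_DIR \
    --output_dir finetune_output --overwrite_output_dir \
    --max_source_length 128 --max_target_length 128 \
    --val_max_target_length 142 --test_max_target_length 142 \
    --do_train --do_eval --do_predict \
    --evaluation_strategy steps \
    --predict_with_generate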
@@ -93,19 +93,27 @@ class DataTrainingArguments:
"than this will be truncated, sequences shorter will be padded."
},
)
max_length: Optional[int] = field(
max_target_length: Optional[int] = field(
default=128,
metadata={
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
eval_max_length: Optional[int] = field(
val_max_target_length: Optional[int] = field(
default=142,
metadata={
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded. "
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
"during ``evaluate`` and ``predict``."
},
)
test_max_target_length: Optional[int] = field(
default=142,
metadata={
"help": "The maximum total sequence length for test target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
" This argument is also used to override the ``max_length`` param of ``model.generate``, which is used during ``evaluate`` and ``predict``"
},
)
n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."})
@@ -233,7 +241,7 @@ def main():
type_path="train",
data_dir=data_args.data_dir,
n_obs=data_args.n_train,
max_target_length=data_args.max_length,
max_target_length=data_args.max_target_length,
max_source_length=data_args.max_source_length,
prefix=model.config.prefix or "",
)
@@ -246,7 +254,7 @@ def main():
type_path="val",
data_dir=data_args.data_dir,
n_obs=data_args.n_val,
max_target_length=data_args.eval_max_length,
max_target_length=data_args.val_max_target_length,
max_source_length=data_args.max_source_length,
prefix=model.config.prefix or "",
)
@@ -259,7 +267,7 @@ def main():
type_path="test",
data_dir=data_args.data_dir,
n_obs=data_args.n_test,
max_target_length=data_args.eval_max_length,
max_target_length=data_args.test_max_target_length,
max_source_length=data_args.max_source_length,
prefix=model.config.prefix or "",
)
@@ -310,7 +318,7 @@ def main():
logger.info("*** Evaluate ***")
metrics = trainer.evaluate(
metric_key_prefix="val", max_length=data_args.eval_max_length, num_beams=data_args.eval_beams
metric_key_prefix="val", max_length=data_args.val_max_target_length, num_beams=data_args.eval_beams
)
metrics["val_n_objs"] = data_args.n_val
metrics["val_loss"] = round(metrics["val_loss"], 4)
@@ -326,7 +334,7 @@ def main():
test_output = trainer.predict(
test_dataset=test_dataset,
metric_key_prefix="test",
max_length=data_args.eval_max_length,
max_length=data_args.val_max_target_length,
num_beams=data_args.eval_beams,
)
metrics = test_output.metrics
@@ -137,8 +137,8 @@ class TestFinetuneTrainer(TestCasePlus):
--n_train 8
--n_val 8
--max_source_length {max_len}
--max_length {max_len}
--eval_max_length {max_len}
--max_target_length {max_len}
--val_max_target_length {max_len}
--do_train
--do_eval
--do_predict
@@ -29,7 +29,8 @@ python finetune_trainer.py \
--freeze_encoder --freeze_embeds \
--num_train_epochs=6 \
--save_steps 3000 --eval_steps 3000 \
--max_source_length $MAX_LEN --max_length $MAX_LEN --eval_max_length $MAX_LEN \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN \
--val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
--do_train --do_eval --do_predict \
--evaluation_strategy steps \
--predict_with_generate --logging_first_step \
@@ -30,7 +30,8 @@ python xla_spawn.py --num_cores $TPU_NUM_CORES \
--num_train_epochs=6 \
--save_steps 500 --eval_steps 500 \
--logging_first_step --logging_steps 200 \
--max_source_length $MAX_LEN --max_length $MAX_LEN --eval_max_length $MAX_LEN \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN \
--val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
--do_train --do_eval \
--evaluation_strategy steps \
--prediction_loss_only \
@@ -32,7 +32,7 @@ python finetune_trainer.py \
--num_train_epochs=2 \
--save_steps 3000 --eval_steps 3000 \
--logging_first_step \
--max_length 56 --eval_max_length $MAX_TGT_LEN \
--max_target_length 56 --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN\
--do_train --do_eval --do_predict \
--evaluation_strategy steps \
--predict_with_generate --sortish_sampler \
@@ -24,7 +24,7 @@ python finetune_trainer.py \
--src_lang en_XX --tgt_lang ro_RO \
--freeze_embeds \
--per_device_train_batch_size=4 --per_device_eval_batch_size=4 \
--max_source_length 128 --max_length 128 --eval_max_length 128 \
--max_source_length 128 --max_target_length 128 --val_max_target_length 128 --test_max_target_length 128\
--sortish_sampler \
--num_train_epochs 6 \
--save_steps 25000 --eval_steps 25000 --logging_steps 1000 \
@@ -330,7 +330,7 @@ class Seq2SeqDataCollator:
[x["src_texts"] for x in batch],
tgt_texts=[x["tgt_texts"] for x in batch],
max_length=self.data_args.max_source_length,
max_target_length=self.data_args.max_length,
max_target_length=self.data_args.max_target_length,
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
return_tensors="pt",
**self.dataset_kwargs,