Unverified Commit c76de105 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add generate kwargs to Seq2SeqTrainingArguments (#13339)

* Add generate kwargs to Seq2SeqTrainingArguments

* typo

* Address review comments + doc

* Style
parent 702f4a49
@@ -556,12 +556,15 @@ def main():
     # Evaluation
     results = {}
+    max_length = (
+        training_args.generation_max_length
+        if training_args.generation_max_length is not None
+        else data_args.val_max_target_length
+    )
+    num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
-        metrics = trainer.evaluate(
-            max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
-        )
+        metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
         max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
         metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
@@ -572,10 +575,7 @@ def main():
         logger.info("*** Predict ***")
         predict_results = trainer.predict(
-            predict_dataset,
-            metric_key_prefix="predict",
-            max_length=data_args.val_max_target_length,
-            num_beams=data_args.num_beams,
+            predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
         )
         metrics = predict_results.metrics
         max_predict_samples = (
...
@@ -549,12 +549,16 @@ def main():
     # Evaluation
     results = {}
+    max_length = (
+        training_args.generation_max_length
+        if training_args.generation_max_length is not None
+        else data_args.val_max_target_length
+    )
+    num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
-        metrics = trainer.evaluate(
-            max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
-        )
+        metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
         max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
         metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
@@ -565,10 +569,7 @@ def main():
         logger.info("*** Predict ***")
         predict_results = trainer.predict(
-            predict_dataset,
-            metric_key_prefix="predict",
-            max_length=data_args.val_max_target_length,
-            num_beams=data_args.num_beams,
+            predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
         )
         metrics = predict_results.metrics
         max_predict_samples = (
...
@@ -70,10 +70,8 @@ class Seq2SeqTrainer(Trainer):
             A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
             dictionary also contains the epoch number which comes from the training state.
         """
-        if max_length is not None or not hasattr(self, "_max_length"):
-            self._max_length = max_length
-        if num_beams is not None or not hasattr(self, "_num_beams"):
-            self._num_beams = num_beams
+        self._max_length = max_length if max_length is not None else self.args.generation_max_length
+        self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
         return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

     def predict(
@@ -119,10 +117,8 @@ class Seq2SeqTrainer(Trainer):
             - metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
               contained labels).
         """
-        if max_length is not None or not hasattr(self, "_max_length"):
-            self._max_length = max_length
-        if num_beams is not None or not hasattr(self, "_num_beams"):
-            self._num_beams = num_beams
+        self._max_length = max_length if max_length is not None else self.args.generation_max_length
+        self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
         return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

     def prediction_step(
...
@@ -14,6 +14,7 @@
 import logging
 from dataclasses import dataclass, field
+from typing import Optional

 from .file_utils import add_start_docstrings
 from .training_args import TrainingArguments
@@ -34,9 +35,29 @@ class Seq2SeqTrainingArguments(TrainingArguments):
             the training set.
         predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to use generate to calculate generative metrics (ROUGE, BLEU).
+        generation_max_length (:obj:`int`, `optional`):
+            The :obj:`max_length` to use on each evaluation loop when :obj:`predict_with_generate=True`. Will default to
+            the :obj:`max_length` value of the model configuration.
+        generation_num_beams (:obj:`int`, `optional`):
+            The :obj:`num_beams` to use on each evaluation loop when :obj:`predict_with_generate=True`. Will default to the
+            :obj:`num_beams` value of the model configuration.
     """

     sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."})
     predict_with_generate: bool = field(
         default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
     )
+    generation_max_length: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default "
+            "to the `max_length` value of the model configuration."
+        },
+    )
+    generation_num_beams: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default "
+            "to the `num_beams` value of the model configuration."
+        },
+    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment