Unverified Commit 12e02e33 authored by cchen-dialpad's avatar cchen-dialpad Committed by GitHub
Browse files

`Seq2SeqTrainer` set max_length and num_beams only when non None (#12899)

* set max_length and num_beams only when non None

* fix instance variables

* fix code style
parent ba15fe79
...@@ -70,8 +70,10 @@ class Seq2SeqTrainer(Trainer): ...@@ -70,8 +70,10 @@ class Seq2SeqTrainer(Trainer):
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
dictionary also contains the epoch number which comes from the training state. dictionary also contains the epoch number which comes from the training state.
""" """
self._max_length = max_length if max_length is not None or not hasattr(self, "_max_length"):
self._num_beams = num_beams self._max_length = max_length
if num_beams is not None or not hasattr(self, "_num_beams"):
self._num_beams = num_beams
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
def predict( def predict(
...@@ -117,8 +119,10 @@ class Seq2SeqTrainer(Trainer): ...@@ -117,8 +119,10 @@ class Seq2SeqTrainer(Trainer):
- metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset - metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
contained labels). contained labels).
""" """
self._max_length = max_length if max_length is not None or not hasattr(self, "_max_length"):
self._num_beams = num_beams self._max_length = max_length
if num_beams is not None or not hasattr(self, "_num_beams"):
self._num_beams = num_beams
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
def prediction_step( def prediction_step(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment