Unverified Commit 12e02e33 authored by cchen-dialpad's avatar cchen-dialpad Committed by GitHub
Browse files

`Seq2SeqTrainer` set max_length and num_beams only when non None (#12899)

* set max_length and num_beams only when non None

* fix instance variables

* fix code style
parent ba15fe79
......@@ -70,7 +70,9 @@ class Seq2SeqTrainer(Trainer):
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
dictionary also contains the epoch number which comes from the training state.
"""
if max_length is not None or not hasattr(self, "_max_length"):
self._max_length = max_length
if num_beams is not None or not hasattr(self, "_num_beams"):
self._num_beams = num_beams
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
......@@ -117,7 +119,9 @@ class Seq2SeqTrainer(Trainer):
- metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
contained labels).
"""
if max_length is not None or not hasattr(self, "_max_length"):
self._max_length = max_length
if num_beams is not None or not hasattr(self, "_num_beams"):
self._num_beams = num_beams
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment