"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "3eceaa3637197fa78dd3525cb3df57fcaf5ba00d"
Unverified Commit 12e02e33 authored by cchen-dialpad's avatar cchen-dialpad Committed by GitHub
Browse files

`Seq2SeqTrainer` set max_length and num_beams only when non None (#12899)

* set max_length and num_beams only when non None

* fix instance variables

* fix code style
parent ba15fe79
......@@ -70,8 +70,10 @@ class Seq2SeqTrainer(Trainer):
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
dictionary also contains the epoch number which comes from the training state.
"""
self._max_length = max_length
self._num_beams = num_beams
if max_length is not None or not hasattr(self, "_max_length"):
self._max_length = max_length
if num_beams is not None or not hasattr(self, "_num_beams"):
self._num_beams = num_beams
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
def predict(
......@@ -117,8 +119,10 @@ class Seq2SeqTrainer(Trainer):
- metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
contained labels).
"""
self._max_length = max_length
self._num_beams = num_beams
if max_length is not None or not hasattr(self, "_max_length"):
self._max_length = max_length
if num_beams is not None or not hasattr(self, "_num_beams"):
self._num_beams = num_beams
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
def prediction_step(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment