Unverified Commit 8e74eca7 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

push (#9320)

parent 61443cd7
...@@ -171,7 +171,9 @@ class Seq2SeqTrainer(Trainer): ...@@ -171,7 +171,9 @@ class Seq2SeqTrainer(Trainer):
""" """
if not self.args.predict_with_generate or prediction_loss_only: if not self.args.predict_with_generate or prediction_loss_only:
return super()(self, model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys) return super().prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
)
has_labels = "labels" in inputs has_labels = "labels" in inputs
inputs = self._prepare_inputs(inputs) inputs = self._prepare_inputs(inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment