Unverified Commit e823d819 authored by Haram Lee's avatar Haram Lee Committed by GitHub
Browse files

Add a condition for checking labels (#14211)

parent b3385963
...@@ -196,9 +196,12 @@ class Seq2SeqTrainer(Trainer): ...@@ -196,9 +196,12 @@ class Seq2SeqTrainer(Trainer):
if self.args.prediction_loss_only: if self.args.prediction_loss_only:
return (loss, None, None) return (loss, None, None)
labels = inputs["labels"] if has_labels:
if labels.shape[-1] < gen_kwargs["max_length"]: labels = inputs["labels"]
labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) if labels.shape[-1] < gen_kwargs["max_length"]:
labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
else:
labels = None
return (loss, generated_tokens, labels) return (loss, generated_tokens, labels)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment