Unverified commit 3a2ed967 authored by NielsRogge, committed by GitHub

Fix Seq2SeqTrainer (#15603)


Co-authored-by: Niels Rogge <nielsrogge@Nielss-MBP.localdomain>
parent 724e51c6
@@ -161,6 +161,9 @@ class Seq2SeqTrainer(Trainer):
             "synced_gpus": True if is_deepspeed_zero3_enabled() else False,
         }
 
+        if "attention_mask" in inputs:
+            gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
+
         # prepare generation inputs
         # some encoder-decoder models can have varying encoder's and thus
         # varying model input names
@@ -171,7 +174,6 @@ class Seq2SeqTrainer(Trainer):
         generated_tokens = self.model.generate(
             generation_inputs,
-            attention_mask=inputs.get("attention_mask", None),
             **gen_kwargs,
         )
         # in case the batch is shorter than max length, the output should be padded
...
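
Note on the change: before this commit, the trainer always passed attention_mask=inputs.get("attention_mask", None) as an explicit keyword argument to generate(); afterwards, the mask travels through the shared gen_kwargs dict, and only when the batch actually contains one. A minimal, self-contained sketch of that pattern follows; the helper name build_gen_kwargs and its batch argument are hypothetical, introduced here for illustration, not part of the transformers API.

from typing import Any, Dict

def build_gen_kwargs(batch: Dict[str, Any], max_length: int, num_beams: int) -> Dict[str, Any]:
    # Hypothetical helper illustrating the commit's pattern: collect all
    # generation arguments in one dict rather than passing some explicitly.
    gen_kwargs: Dict[str, Any] = {
        "max_length": max_length,
        "num_beams": num_beams,
    }
    if "attention_mask" in batch:
        # Mirrors the added lines: the mask rides along in gen_kwargs instead
        # of being forwarded unconditionally as a generate() keyword argument.
        gen_kwargs["attention_mask"] = batch["attention_mask"]
    return gen_kwargs

# Usage (illustrative):
#   generated = model.generate(generation_inputs, **build_gen_kwargs(batch, 128, 4))

Routing everything through one dict keeps generate() from receiving attention_mask=None for models whose batches (for example, speech inputs) do not include a mask.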