Unverified Commit 3a2ed967 authored by NielsRogge, committed by GitHub

Fix Seq2SeqTrainer (#15603)


Co-authored-by: Niels Rogge <nielsrogge@Nielss-MBP.localdomain>
parent 724e51c6
@@ -161,6 +161,9 @@ class Seq2SeqTrainer(Trainer):
             "synced_gpus": True if is_deepspeed_zero3_enabled() else False,
         }
 
+        if "attention_mask" in inputs:
+            gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
+
         # prepare generation inputs
         # some encoder-decoder models can have varying encoder's and thus
         # varying model input names
@@ -171,7 +174,6 @@ class Seq2SeqTrainer(Trainer):
 
         generated_tokens = self.model.generate(
             generation_inputs,
-            attention_mask=inputs.get("attention_mask", None),
             **gen_kwargs,
         )
         # in case the batch is shorter than max length, the output should be padded
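Taken together, the diff stops passing attention_mask to generate() unconditionally (where inputs.get("attention_mask", None) could hand the model an explicit None) and instead adds it to gen_kwargs only when the batch actually contains a mask. Below is a minimal, self-contained sketch of that control flow; build_gen_kwargs, max_length, and num_beams are hypothetical stand-ins, not names from the commit.

# Minimal sketch of the gating introduced by the fix; not code from the commit.
def build_gen_kwargs(inputs: dict, max_length: int = 128, num_beams: int = 4) -> dict:
    """Assemble keyword arguments for model.generate(), forwarding
    attention_mask only when the input batch provides one."""
    gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
    if "attention_mask" in inputs:
        gen_kwargs["attention_mask"] = inputs["attention_mask"]
    return gen_kwargs

# Usage with a hypothetical encoder-decoder model and tokenized batch:
# generated_tokens = model.generate(inputs["input_ids"], **build_gen_kwargs(inputs))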