Unverified commit 34678db4, authored by Matt and committed by GitHub

Fix Seq2seqTrainer decoder attention mask (#26841)

Don't drop decoder_input_ids without also dropping decoder_attention_mask
parent 280c757f
@@ -288,7 +288,9 @@ class Seq2SeqTrainer(Trainer):
             and "decoder_input_ids" in generation_inputs
             and generation_inputs["labels"].shape == generation_inputs["decoder_input_ids"].shape
         ):
-            generation_inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"}
+            generation_inputs = {
+                k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask")
+            }
         generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
         # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
...
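The change is small: when the collator-built decoder_input_ids mirror the labels, the trainer rebuilds the generate() inputs without them, and this fix removes the matching decoder_attention_mask at the same time, since a leftover mask would no longer correspond to the decoder inputs that generate() builds on its own. Below is a minimal sketch of the filtering logic only; the key set and tensor values are invented for illustration, not the real batch contents.

```python
import torch

# Hypothetical batch resembling what a seq2seq data collator might produce.
inputs = {
    "input_ids": torch.tensor([[5, 6, 7]]),              # encoder tokens
    "attention_mask": torch.tensor([[1, 1, 1]]),          # encoder mask
    "labels": torch.tensor([[8, 9, 10, 11]]),             # targets for the loss
    "decoder_input_ids": torch.tensor([[0, 8, 9, 10]]),   # shifted labels
    "decoder_attention_mask": torch.tensor([[1, 1, 1, 1]]),
}

# Old behaviour: only decoder_input_ids was dropped, so generate() still
# received a decoder_attention_mask with no matching decoder inputs.
old = {k: v for k, v in inputs.items() if k != "decoder_input_ids"}

# Fixed behaviour: both decoder-side keys are dropped together, letting
# generate() build its own decoder inputs (and mask) from the start token.
new = {k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask")}

print(sorted(old))  # 'decoder_attention_mask' still present -> mismatch
print(sorted(new))  # both decoder-side keys removed
```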