Unverified Commit 95a90410 authored by Yih-Dar, committed by GitHub

Fix `test_finetune_bert2bert` (#25984)


Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 86ffef87
@@ -281,15 +281,16 @@ class Seq2SeqTrainer(Trainer):
             gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
         )
 
+        generation_inputs = inputs.copy()
         # If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate
         # (otherwise, it would continue generating from the padded `decoder_input_ids`)
         if (
-            "labels" in inputs
-            and "decoder_input_ids" in inputs
-            and inputs["labels"].shape == inputs["decoder_input_ids"].shape
+            "labels" in generation_inputs
+            and "decoder_input_ids" in generation_inputs
+            and generation_inputs["labels"].shape == generation_inputs["decoder_input_ids"].shape
         ):
-            inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"}
-        generated_tokens = self.model.generate(**inputs, **gen_kwargs)
+            generation_inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"}
+        generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
 
         # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
         # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
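For readers skimming the diff: the old code rebound the local `inputs` dict when evicting `decoder_input_ids`, while the new code drops the key from a copy (`generation_inputs`), so the original `inputs` still carries `decoder_input_ids` for the rest of the prediction step. A minimal sketch of that pattern with toy dicts follows; this is not the actual Trainer code, and the length check merely stands in for the tensor `.shape` comparison.

```python
# Toy illustration of working on a copy instead of rebinding the shared dict.
inputs = {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "decoder_input_ids": [0, 4, 5]}

# Shallow copy: evicting a key here leaves the original `inputs` untouched.
generation_inputs = inputs.copy()
if (
    "labels" in generation_inputs
    and "decoder_input_ids" in generation_inputs
    and len(generation_inputs["labels"]) == len(generation_inputs["decoder_input_ids"])
):
    generation_inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"}

assert "decoder_input_ids" not in generation_inputs  # generation starts from scratch
assert "decoder_input_ids" in inputs                  # original dict is still intact
```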