Unverified Commit 95a90410 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix `test_finetune_bert2bert` (#25984)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 86ffef87
...@@ -281,15 +281,16 @@ class Seq2SeqTrainer(Trainer): ...@@ -281,15 +281,16 @@ class Seq2SeqTrainer(Trainer):
gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
) )
generation_inputs = inputs.copy()
# If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate # If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate
# (otherwise, it would continue generating from the padded `decoder_input_ids`) # (otherwise, it would continue generating from the padded `decoder_input_ids`)
if ( if (
"labels" in inputs "labels" in generation_inputs
and "decoder_input_ids" in inputs and "decoder_input_ids" in generation_inputs
and inputs["labels"].shape == inputs["decoder_input_ids"].shape and generation_inputs["labels"].shape == generation_inputs["decoder_input_ids"].shape
): ):
inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"} generation_inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"}
generated_tokens = self.model.generate(**inputs, **gen_kwargs) generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
# Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
# TODO: remove this hack when the legacy code that initializes generation_config from a model config is # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment