"...resnet50_tensorflow.git" did not exist on "383c6e309dba446642a6a681af82474e8558d5f2"
Unverified Commit 895ae3b5 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Seq2SeqTrainer: Evict decoder_input_ids only when it is created from labels (#22772)

parent daf53241
......@@ -265,9 +265,14 @@ class Seq2SeqTrainer(Trainer):
gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
)
# TODO (Joao): the following line is needed to keep a consistent result on SQUAD. Ideally, we should not block
# users from preparing a dataset with `decoder_input_ids`.
inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"}
# If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate
# (otherwise, it would continue generating from the padded `decoder_input_ids`)
if (
"labels" in inputs
and "decoder_input_ids" in inputs
and inputs["labels"].shape == inputs["decoder_input_ids"].shape
):
inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"}
generated_tokens = self.model.generate(**inputs, **gen_kwargs)
# Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment