"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "576cd45a57501856bf719f2eb6186325fb6d2f88"
Unverified commit d31497b1 authored by regisss, committed by GitHub

Do not log the generation config for each prediction step in TrainerSeq2Seq (#21385)

Do not log the generation config for each iteration
parent 98d40fed
@@ -199,6 +199,11 @@ class Seq2SeqTrainer(Trainer):
             generation_inputs,
             **gen_kwargs,
         )
+        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
+        # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
+        # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
+        if self.model.generation_config._from_model_config:
+            self.model.generation_config._from_model_config = False
         # in case the batch is shorter than max length, the output should be padded
         if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
             generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
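For context, the sketch below is a self-contained toy illustration (not the real transformers code) of why clearing the `_from_model_config` flag silences the per-step log: the legacy path linked in the TODO re-derives the generation config from the model config, and warns, whenever that flag is still set. `ToyModel` and `GenerationConfigStub` are made-up names used only for this illustration.

```python
import warnings


class GenerationConfigStub:
    """Hypothetical stand-in for a generation config object."""

    def __init__(self, from_model_config=False):
        # Mirrors the private flag the Trainer hack flips.
        self._from_model_config = from_model_config


class ToyModel:
    """Toy model whose generate() mimics the legacy warning path."""

    def __init__(self):
        self.generation_config = GenerationConfigStub(from_model_config=True)

    def generate(self, inputs, **gen_kwargs):
        # In the legacy code path, a config derived from the model config
        # triggers a re-initialization plus a log message on every call.
        if self.generation_config._from_model_config:
            warnings.warn(
                "Generation config was created from the model config; "
                "without the fix this would be logged on every prediction step."
            )
        return inputs  # no-op "generation" for the sketch


model = ToyModel()

for step in range(3):
    model.generate([1, 2, 3])
    # The Seq2SeqTrainer hack: clear the flag after the first call so the
    # message is emitted at most once per evaluation loop.
    if model.generation_config._from_model_config:
        model.generation_config._from_model_config = False
```

Running the loop above emits the warning only on the first step, which matches the intent of the commit: keep the legacy behavior intact but stop repeating the log for every prediction step.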