"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "f435003e0c2dd152a2117d11c0ab6fcd4f2d84c0"
Unverified Commit 48706c71 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Seq2SeqTrainer: use unwrapped model to retrieve the generation config (#22584)

parent 0aa1153f
...@@ -277,7 +277,7 @@ class Seq2SeqTrainer(Trainer): ...@@ -277,7 +277,7 @@ class Seq2SeqTrainer(Trainer):
self.model.generation_config._from_model_config = False self.model.generation_config._from_model_config = False
# Retrieves GenerationConfig from model.generation_config # Retrieves GenerationConfig from model.generation_config
gen_config = model.generation_config gen_config = self.model.generation_config
# in case the batch is shorter than max length, the output should be padded # in case the batch is shorter than max length, the output should be padded
if generated_tokens.shape[-1] < gen_config.max_length: if generated_tokens.shape[-1] < gen_config.max_length:
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length) generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment