Unverified Commit 738944c9 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Trainer: missing None check (#22404)

missing None check
parent 53155b52
......@@ -281,7 +281,7 @@ class Seq2SeqTrainer(Trainer):
# in case the batch is shorter than max length, the output should be padded
if generated_tokens.shape[-1] < gen_config.max_length:
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length)
elif generated_tokens.shape[-1] < gen_config.max_new_tokens + 1:
elif gen_config.max_new_tokens is not None and generated_tokens.shape[-1] < gen_config.max_new_tokens + 1:
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1)
with torch.no_grad():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment