"test/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "9fde0e8e003d894003ebafa5a30160f5fb9dc2c2"
Unverified Commit 738944c9 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Trainer: missing None check (#22404)

missing None check
parent 53155b52
...@@ -281,7 +281,7 @@ class Seq2SeqTrainer(Trainer): ...@@ -281,7 +281,7 @@ class Seq2SeqTrainer(Trainer):
# in case the batch is shorter than max length, the output should be padded # in case the batch is shorter than max length, the output should be padded
if generated_tokens.shape[-1] < gen_config.max_length: if generated_tokens.shape[-1] < gen_config.max_length:
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length) generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length)
elif generated_tokens.shape[-1] < gen_config.max_new_tokens + 1: elif gen_config.max_new_tokens is not None and generated_tokens.shape[-1] < gen_config.max_new_tokens + 1:
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1) generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1)
with torch.no_grad(): with torch.no_grad():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment