Unverified commit ab663b22 authored by kumapo, committed by GitHub

reflect max_new_tokens in `Seq2SeqTrainer` (#18786)

* reflect max_new_tokens in gen_kwargs to `trainer.generate()`

* reflect max_new_tokens in `Seq2SeqTrainer`

* remove unnecessary variable

* Trigger CI

* fix style
parent f719c037
@@ -68,9 +68,8 @@ class Seq2SeqTrainer(Trainer):
         """
         gen_kwargs = gen_kwargs.copy()
-        gen_kwargs["max_length"] = (
-            gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
-        )
+        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
+            gen_kwargs["max_length"] = self.args.generation_max_length
         gen_kwargs["num_beams"] = (
             gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
         )
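Both this hunk and the matching one in `predict` below replace the unconditional `max_length` default with a guarded fallback: `args.generation_max_length` is only applied when the caller supplied neither `max_length` nor `max_new_tokens` in `gen_kwargs`. A minimal sketch of that precedence, written as a standalone helper (the function name and plain-dict interface are illustrative, not part of the library):

```python
def resolve_length_kwargs(gen_kwargs: dict, default_max_length: int) -> dict:
    """Only inject a max_length default when the caller provided
    neither max_length nor max_new_tokens (illustrative helper)."""
    gen_kwargs = gen_kwargs.copy()
    if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
        gen_kwargs["max_length"] = default_max_length
    return gen_kwargs


# A caller passing max_new_tokens keeps it; no max_length is added on top.
print(resolve_length_kwargs({"max_new_tokens": 64}, default_max_length=128))
# {'max_new_tokens': 64}

# With neither key set, the trainer-level default still applies.
print(resolve_length_kwargs({}, default_max_length=128))
# {'max_length': 128}
```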
@@ -126,9 +125,8 @@ class Seq2SeqTrainer(Trainer):
         """
         gen_kwargs = gen_kwargs.copy()
-        gen_kwargs["max_length"] = (
-            gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
-        )
+        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
+            gen_kwargs["max_length"] = self.args.generation_max_length
         gen_kwargs["num_beams"] = (
             gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
         )
@@ -174,9 +172,8 @@ class Seq2SeqTrainer(Trainer):
         # XXX: adapt synced_gpus for fairscale as well
         gen_kwargs = self._gen_kwargs.copy()
-        gen_kwargs["max_length"] = (
-            gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.model.config.max_length
-        )
+        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
+            gen_kwargs["max_length"] = self.model.config.max_length
         gen_kwargs["num_beams"] = (
             gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
         )
@@ -203,8 +200,12 @@ class Seq2SeqTrainer(Trainer):
             **gen_kwargs,
         )
         # in case the batch is shorter than max length, the output should be padded
-        if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
+        if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
             generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
+        elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
+            gen_kwargs["max_new_tokens"] + 1
+        ):
+            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)

         with torch.no_grad():
             with self.compute_loss_context_manager():
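The padding target in this hunk (and in the labels hunk below) becomes `max_new_tokens + 1` when only `max_new_tokens` is set: for encoder-decoder models, the output of `generate()` begins with the decoder start token, so a budget of `max_new_tokens` new tokens yields at most `max_new_tokens + 1` ids. A small sketch of that check, using a stand-in for the trainer's `_pad_tensors_to_max_len` that takes the pad id explicitly:

```python
import torch


def pad_to_len(tensor: torch.Tensor, max_len: int, pad_id: int) -> torch.Tensor:
    """Right-pad a (batch, seq_len) tensor of token ids to max_len with pad_id."""
    padded = tensor.new_full((tensor.shape[0], max_len), pad_id)
    padded[:, : tensor.shape[-1]] = tensor
    return padded


generated = torch.tensor([[0, 11, 12, 13]])  # decoder start token + 3 generated ids
max_new_tokens = 6
# Mirror the check in the diff: the output can be at most max_new_tokens + 1 long,
# because the decoder start token counts toward the output length.
if generated.shape[-1] < max_new_tokens + 1:
    generated = pad_to_len(generated, max_new_tokens + 1, pad_id=1)
print(generated.shape)  # torch.Size([1, 7])
```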
@@ -222,8 +223,12 @@ class Seq2SeqTrainer(Trainer):
         if has_labels:
             labels = inputs["labels"]
-            if labels.shape[-1] < gen_kwargs["max_length"]:
+            if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
                 labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
+            elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
+                gen_kwargs["max_new_tokens"] + 1
+            ):
+                labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1))
         else:
             labels = None
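With these changes, `max_new_tokens` passed through `gen_kwargs` reaches `model.generate()` without a competing `max_length` default. A usage sketch, assuming an already-constructed `Seq2SeqTrainer` with `predict_with_generate=True` (`trainer` and `test_dataset` below are placeholders for objects set up elsewhere):

```python
# `trainer` is an existing Seq2SeqTrainer (model, args with
# predict_with_generate=True, datasets and tokenizer set up elsewhere).

# max_new_tokens is forwarded to model.generate() through gen_kwargs;
# the trainer no longer injects a max_length default on top of it.
metrics = trainer.evaluate(max_new_tokens=64, num_beams=4)

# The same keyword works for predict(); test_dataset is a placeholder.
predictions = trainer.predict(test_dataset, max_new_tokens=64)
```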