Unverified Commit 4c0e251d authored by Gaurav Kumbhat, committed by GitHub

🐛 Handle empty gen_kwargs for seq2seq trainer prediction_step function (#24759)

* 🐛 Handle empty gen_kwargs for seq2seq trainer prediction_step fn

Signed-off-by: gkumbhat <kumbhat.gaurav@gmail.com>

* Update src/transformers/trainer_seq2seq.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------
Signed-off-by: gkumbhat <kumbhat.gaurav@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 253d43d4
src/transformers/trainer_seq2seq.py
@@ -221,6 +221,7 @@ class Seq2SeqTrainer(Trainer):
         inputs: Dict[str, Union[torch.Tensor, Any]],
         prediction_loss_only: bool,
         ignore_keys: Optional[List[str]] = None,
+        **gen_kwargs,
     ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
         """
         Perform an evaluation step on `model` using `inputs`.
@@ -237,6 +238,8 @@ class Seq2SeqTrainer(Trainer):
                 argument `labels`. Check your model's documentation for all accepted arguments.
             prediction_loss_only (`bool`):
                 Whether or not to return the loss only.
+            gen_kwargs:
+                Additional `generate` specific kwargs.

         Return:
             Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
@@ -254,7 +257,10 @@ class Seq2SeqTrainer(Trainer):
         # XXX: adapt synced_gpus for fairscale as well
         # Priority (handled in generate):
         # gen_kwargs > model.generation_config > default GenerationConfig()
-        gen_kwargs = self._gen_kwargs.copy()
+        if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"):
+            gen_kwargs = self._gen_kwargs.copy()
+
         if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
             gen_kwargs["max_length"] = self.model.config.max_length
         gen_kwargs["num_beams"] = (
...
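The snippet below is a minimal, self-contained sketch of the fallback pattern this change introduces. It is illustrative only (a toy class, not the real Seq2SeqTrainer internals): the real trainer stashes generation kwargs on `self._gen_kwargs` inside `evaluate()`/`predict()`, and `prediction_step()` now prefers kwargs passed to it directly, falling back to the stashed ones only when they exist.

from typing import Any, Dict


class _SketchSeq2SeqTrainer:
    """Toy illustration of the gen_kwargs fallback, not the real Seq2SeqTrainer."""

    def evaluate(self, **gen_kwargs: Any) -> Dict[str, Any]:
        # evaluate()/predict() stash the caller's generation kwargs on the trainer.
        self._gen_kwargs = gen_kwargs
        return self.prediction_step()

    def prediction_step(self, **gen_kwargs: Any) -> Dict[str, Any]:
        # Fall back to the stashed kwargs only when the caller passed none AND
        # the attribute exists, so calling prediction_step() directly without a
        # prior evaluate()/predict() no longer raises AttributeError.
        if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"):
            gen_kwargs = self._gen_kwargs.copy()
        return gen_kwargs


trainer = _SketchSeq2SeqTrainer()
print(trainer.prediction_step(max_new_tokens=16))  # explicit kwargs take priority
print(trainer.prediction_step())                   # no stash yet -> {} instead of a crash
print(trainer.evaluate(num_beams=4))               # kwargs flow through _gen_kwargs

The guard mirrors the diff above: explicit gen_kwargs win, stashed kwargs are only a fallback, and the absence of both is handled gracefully.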