Unverified Commit f8e90d6f authored by Jannis Born's avatar Jannis Born Committed by GitHub
Browse files

Fix typing error in Trainer class (prediction_step) (#11138)

* fix: docstrings in prediction_step

* ci: Satisfy line length requirements

* ci: character length requirements
parent ffe07617
...@@ -1966,7 +1966,7 @@ class Trainer: ...@@ -1966,7 +1966,7 @@ class Trainer:
inputs: Dict[str, Union[torch.Tensor, Any]], inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool, prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None, ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
""" """
Perform an evaluation step on :obj:`model` using obj:`inputs`. Perform an evaluation step on :obj:`model` using obj:`inputs`.
...@@ -1987,8 +1987,8 @@ class Trainer: ...@@ -1987,8 +1987,8 @@ class Trainer:
gathering predictions. gathering predictions.
Return: Return:
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
labels (each being optional). logits and labels (each being optional).
""" """
has_labels = all(inputs.get(k) is not None for k in self.label_names) has_labels = all(inputs.get(k) is not None for k in self.label_names)
inputs = self._prepare_inputs(inputs) inputs = self._prepare_inputs(inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment