Unverified Commit bf78f523 authored by Will Frey's avatar Will Frey Committed by GitHub
Browse files

Fix StoppingCriteria ABC signature (#12918)

Change `score` -> `scores` because the argument is not positional-only, so you need consistently named parameters for the subclasses. The subclasses appear to favor `scores` over `score`.
parent 63f2b9ab
...@@ -35,7 +35,7 @@ class StoppingCriteria(ABC): ...@@ -35,7 +35,7 @@ class StoppingCriteria(ABC):
"""Abstract base class for all stopping criteria that can be applied during generation.""" """Abstract base class for all stopping criteria that can be applied during generation."""
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
raise NotImplementedError("StoppingCriteria needs to be subclassed") raise NotImplementedError("StoppingCriteria needs to be subclassed")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment