Unverified Commit cc309fd4 authored by Raushan Turganbay's avatar Raushan Turganbay Committed by GitHub
Browse files

pass kwargs in stopping criteria list (#28927)

parent 0b693e90
......@@ -129,7 +129,7 @@ class MaxTimeCriteria(StoppingCriteria):
class StoppingCriteriaList(list):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return any(criteria(input_ids, scores) for criteria in self)
return any(criteria(input_ids, scores, **kwargs) for criteria in self)
@property
def max_length(self) -> Optional[int]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment