Unverified Commit 8f2f0f0f authored by Raushan Turganbay's avatar Raushan Turganbay Committed by GitHub
Browse files

Track each row separately for stopping criteria (#29116)

parent ece1b62b
...@@ -29,7 +29,8 @@ STOPPING_CRITERIA_INPUTS_DOCSTRING = r""" ...@@ -29,7 +29,8 @@ STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
Additional stopping criteria specific kwargs. Additional stopping criteria specific kwargs.
Return: Return:
`bool`. `False` indicates we should continue, `True` indicates we should stop. `torch.BoolTensor`. (`torch.BoolTensor` of shape `(batch_size,)`), where `True` indicates we stop generation
for a particular row, `False` indicates we should continue.
""" """
...@@ -42,7 +43,7 @@ class StoppingCriteria(ABC): ...@@ -42,7 +43,7 @@ class StoppingCriteria(ABC):
""" """
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
raise NotImplementedError("StoppingCriteria needs to be subclassed") raise NotImplementedError("StoppingCriteria needs to be subclassed")
...@@ -63,7 +64,7 @@ class MaxLengthCriteria(StoppingCriteria): ...@@ -63,7 +64,7 @@ class MaxLengthCriteria(StoppingCriteria):
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
cur_len = input_ids.shape[-1] cur_len = input_ids.shape[-1]
is_done = cur_len >= self.max_length is_done = cur_len >= self.max_length
if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings: if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
...@@ -72,7 +73,7 @@ class MaxLengthCriteria(StoppingCriteria): ...@@ -72,7 +73,7 @@ class MaxLengthCriteria(StoppingCriteria):
f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe " f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
"exceptions, performance degradation, or nothing at all." "exceptions, performance degradation, or nothing at all."
) )
return is_done return torch.full((input_ids.shape[0],), is_done, device=input_ids.device)
class MaxNewTokensCriteria(StoppingCriteria): class MaxNewTokensCriteria(StoppingCriteria):
...@@ -100,8 +101,9 @@ class MaxNewTokensCriteria(StoppingCriteria): ...@@ -100,8 +101,9 @@ class MaxNewTokensCriteria(StoppingCriteria):
self.max_length = start_length + max_new_tokens self.max_length = start_length + max_new_tokens
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
return input_ids.shape[-1] >= self.max_length is_done = input_ids.shape[-1] >= self.max_length
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device)
class MaxTimeCriteria(StoppingCriteria): class MaxTimeCriteria(StoppingCriteria):
...@@ -122,14 +124,18 @@ class MaxTimeCriteria(StoppingCriteria): ...@@ -122,14 +124,18 @@ class MaxTimeCriteria(StoppingCriteria):
self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
return time.time() - self.initial_timestamp > self.max_time is_done = time.time() - self.initial_timestamp > self.max_time
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device)
class StoppingCriteriaList(list): class StoppingCriteriaList(list):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
return any(criteria(input_ids, scores, **kwargs) for criteria in self) is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device)
for criteria in self:
is_done = is_done | criteria(input_ids, scores, **kwargs)
return is_done
@property @property
def max_length(self) -> Optional[int]: def max_length(self) -> Optional[int]:
......
...@@ -2195,11 +2195,9 @@ class GenerationMixin: ...@@ -2195,11 +2195,9 @@ class GenerationMixin:
) )
# stop when each sentence is finished # stop when each sentence is finished
if unfinished_sequences.max() == 0: unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = True
# stop if we exceed the maximum length if unfinished_sequences.max() == 0:
if stopping_criteria(input_ids, scores):
this_peer_finished = True this_peer_finished = True
if this_peer_finished and not synced_gpus: if this_peer_finished and not synced_gpus:
...@@ -2478,14 +2476,12 @@ class GenerationMixin: ...@@ -2478,14 +2476,12 @@ class GenerationMixin:
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
) )
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
# stop when each sentence is finished # stop when each sentence is finished
if unfinished_sequences.max() == 0: if unfinished_sequences.max() == 0:
this_peer_finished = True this_peer_finished = True
# stop if we exceed the maximum length
if stopping_criteria(input_ids, scores):
this_peer_finished = True
if this_peer_finished and not synced_gpus: if this_peer_finished and not synced_gpus:
break break
...@@ -2772,14 +2768,12 @@ class GenerationMixin: ...@@ -2772,14 +2768,12 @@ class GenerationMixin:
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
) )
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
# stop when each sentence is finished # stop when each sentence is finished
if unfinished_sequences.max() == 0: if unfinished_sequences.max() == 0:
this_peer_finished = True this_peer_finished = True
# stop if we exceed the maximum length
if stopping_criteria(input_ids, scores):
this_peer_finished = True
if this_peer_finished and not synced_gpus: if this_peer_finished and not synced_gpus:
break break
...@@ -3169,7 +3163,7 @@ class GenerationMixin: ...@@ -3169,7 +3163,7 @@ class GenerationMixin:
# increase cur_len # increase cur_len
cur_len = cur_len + 1 cur_len = cur_len + 1
if beam_scorer.is_done or stopping_criteria(input_ids, scores): if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
if not synced_gpus: if not synced_gpus:
break break
else: else:
...@@ -3516,7 +3510,7 @@ class GenerationMixin: ...@@ -3516,7 +3510,7 @@ class GenerationMixin:
# increase cur_len # increase cur_len
cur_len = cur_len + 1 cur_len = cur_len + 1
if beam_scorer.is_done or stopping_criteria(input_ids, scores): if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
if not synced_gpus: if not synced_gpus:
break break
else: else:
...@@ -3912,7 +3906,7 @@ class GenerationMixin: ...@@ -3912,7 +3906,7 @@ class GenerationMixin:
# increase cur_len # increase cur_len
cur_len = cur_len + 1 cur_len = cur_len + 1
if beam_scorer.is_done or stopping_criteria(input_ids, scores): if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
if not synced_gpus: if not synced_gpus:
break break
else: else:
...@@ -4267,7 +4261,7 @@ class GenerationMixin: ...@@ -4267,7 +4261,7 @@ class GenerationMixin:
# increase cur_len # increase cur_len
cur_len = cur_len + 1 cur_len = cur_len + 1
if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores): if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
if not synced_gpus: if not synced_gpus:
break break
else: else:
...@@ -4657,14 +4651,12 @@ class GenerationMixin: ...@@ -4657,14 +4651,12 @@ class GenerationMixin:
.prod(dim=0) .prod(dim=0)
) )
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
# stop when each sentence is finished # stop when each sentence is finished
if unfinished_sequences.max() == 0: if unfinished_sequences.max() == 0:
this_peer_finished = True this_peer_finished = True
# stop if we exceed the maximum length
if stopping_criteria(input_ids, scores):
this_peer_finished = True
if this_peer_finished and not synced_gpus: if this_peer_finished and not synced_gpus:
break break
......
...@@ -54,37 +54,37 @@ class StoppingCriteriaTestCase(unittest.TestCase): ...@@ -54,37 +54,37 @@ class StoppingCriteriaTestCase(unittest.TestCase):
] ]
) )
self.assertFalse(criteria(input_ids, scores)) self.assertFalse(all(criteria(input_ids, scores)))
input_ids, scores = self._get_tensors(9) input_ids, scores = self._get_tensors(9)
self.assertFalse(criteria(input_ids, scores)) self.assertFalse(all(criteria(input_ids, scores)))
input_ids, scores = self._get_tensors(10) input_ids, scores = self._get_tensors(10)
self.assertTrue(criteria(input_ids, scores)) self.assertTrue(all(criteria(input_ids, scores)))
def test_max_length_criteria(self): def test_max_length_criteria(self):
criteria = MaxLengthCriteria(max_length=10) criteria = MaxLengthCriteria(max_length=10)
input_ids, scores = self._get_tensors(5) input_ids, scores = self._get_tensors(5)
self.assertFalse(criteria(input_ids, scores)) self.assertFalse(all(criteria(input_ids, scores)))
input_ids, scores = self._get_tensors(9) input_ids, scores = self._get_tensors(9)
self.assertFalse(criteria(input_ids, scores)) self.assertFalse(all(criteria(input_ids, scores)))
input_ids, scores = self._get_tensors(10) input_ids, scores = self._get_tensors(10)
self.assertTrue(criteria(input_ids, scores)) self.assertTrue(all(criteria(input_ids, scores)))
def test_max_new_tokens_criteria(self): def test_max_new_tokens_criteria(self):
criteria = MaxNewTokensCriteria(start_length=5, max_new_tokens=5) criteria = MaxNewTokensCriteria(start_length=5, max_new_tokens=5)
input_ids, scores = self._get_tensors(5) input_ids, scores = self._get_tensors(5)
self.assertFalse(criteria(input_ids, scores)) self.assertFalse(all(criteria(input_ids, scores)))
input_ids, scores = self._get_tensors(9) input_ids, scores = self._get_tensors(9)
self.assertFalse(criteria(input_ids, scores)) self.assertFalse(all(criteria(input_ids, scores)))
input_ids, scores = self._get_tensors(10) input_ids, scores = self._get_tensors(10)
self.assertTrue(criteria(input_ids, scores)) self.assertTrue(all(criteria(input_ids, scores)))
criteria_list = StoppingCriteriaList([criteria]) criteria_list = StoppingCriteriaList([criteria])
self.assertEqual(criteria_list.max_length, 10) self.assertEqual(criteria_list.max_length, 10)
...@@ -93,10 +93,10 @@ class StoppingCriteriaTestCase(unittest.TestCase): ...@@ -93,10 +93,10 @@ class StoppingCriteriaTestCase(unittest.TestCase):
input_ids, scores = self._get_tensors(5) input_ids, scores = self._get_tensors(5)
criteria = MaxTimeCriteria(max_time=0.1) criteria = MaxTimeCriteria(max_time=0.1)
self.assertFalse(criteria(input_ids, scores)) self.assertFalse(all(criteria(input_ids, scores)))
criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2) criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2)
self.assertTrue(criteria(input_ids, scores)) self.assertTrue(all(criteria(input_ids, scores)))
def test_validate_stopping_criteria(self): def test_validate_stopping_criteria(self):
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10) validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment