Unverified commit 80d712fa, authored by Nicolas Patry and committed by GitHub

Adding new argument `max_new_tokens` for generate. (#11476)

* Adding new argument `max_new_tokens` for generate.

This is a proposal to add a new argument `max_new_tokens` to `generate`.
It includes a `MaxNewTokensCriteria` that enables callers who don't know
the prompt's token length ahead of time (like pipeline callers) to manage
the length of their generated output more easily (see the usage sketch
below).

* Adding a test for the user warning when both `max_length` and
`max_new_tokens` are used together.

* Removed redundant `no_grad`.
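A minimal usage sketch of the new argument (not part of this commit; the `gpt2` checkpoint and the prompt are only illustrative assumptions):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed checkpoint for illustration only; any generative model works the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Justin Timberlake and Jessica Biel, welcome to parenthood."
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# No need to compute `max_length = input_ids.shape[-1] + n` by hand:
# ask for a fixed number of newly generated tokens regardless of prompt length.
outputs = model.generate(input_ids, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```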
parent 2dd6fb25
@@ -57,6 +57,29 @@ class MaxLengthCriteria(StoppingCriteria):
        return input_ids.shape[-1] >= self.max_length

class MaxNewTokensCriteria(StoppingCriteria):
    """
    This class can be used to stop generation whenever the number of generated tokens exceeds :obj:`max_new_tokens`.
    Keep in mind that, for decoder-only transformers, this will **not** include the initial prompt tokens. This is
    very close to :obj:`MaxLengthCriteria` but ignores the number of initial tokens.

    Args:
        start_length (:obj:`int`):
            The number of initial tokens.
        max_new_tokens (:obj:`int`):
            The maximum number of tokens to generate.
    """

    def __init__(self, start_length: int, max_new_tokens: int):
        self.start_length = start_length
        self.max_new_tokens = max_new_tokens
        self.max_length = start_length + max_new_tokens

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return input_ids.shape[-1] >= self.max_length

class MaxTimeCriteria(StoppingCriteria):
    """
    This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the

@@ -89,6 +112,8 @@ class StoppingCriteriaList(list):
        for stopping_criterium in self:
            if isinstance(stopping_criterium, MaxLengthCriteria):
                return stopping_criterium.max_length
            elif isinstance(stopping_criterium, MaxNewTokensCriteria):
                return stopping_criterium.max_length
        return None
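For reference, a minimal sketch of how the new `MaxNewTokensCriteria` behaves on its own (illustrative only; the tensors below are dummies, since the criteria only inspects the current sequence length):

```python
import torch

from transformers.generation_stopping_criteria import MaxNewTokensCriteria, StoppingCriteriaList

# Suppose the prompt was 5 tokens long and we allow 5 newly generated tokens.
criteria = MaxNewTokensCriteria(start_length=5, max_new_tokens=5)

input_ids = torch.zeros((1, 9), dtype=torch.long)  # 9 tokens generated so far
scores = torch.zeros((1, 100))  # dummy logits; only input_ids.shape matters here

print(criteria(input_ids, scores))  # False: 9 < 5 + 5
print(criteria(torch.zeros((1, 10), dtype=torch.long), scores))  # True: 10 >= 10

# StoppingCriteriaList exposes the effective max_length computed from the criteria.
print(StoppingCriteriaList([criteria]).max_length)  # 10
```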
@@ -42,6 +42,7 @@ from .generation_logits_process import (
)
from .generation_stopping_criteria import (
    MaxLengthCriteria,
    MaxNewTokensCriteria,
    MaxTimeCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
@@ -628,15 +629,15 @@ class GenerationMixin:
        return processors

    def _get_stopping_criteria(
        self, max_length: Optional[int], max_time: Optional[float], max_new_tokens: Optional[int], start_length: int
    ) -> StoppingCriteriaList:
        stopping_criteria = StoppingCriteriaList()
        if max_length is not None:
            stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
        if max_time is not None:
            stopping_criteria.append(MaxTimeCriteria(max_time=max_time))
        if max_new_tokens is not None:
            stopping_criteria.append(MaxNewTokensCriteria(start_length=start_length, max_new_tokens=max_new_tokens))
        return stopping_criteria
    @torch.no_grad()

@@ -661,6 +662,7 @@ class GenerationMixin:
        encoder_no_repeat_ngram_size: Optional[int] = None,
        num_return_sequences: Optional[int] = None,
        max_time: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        use_cache: Optional[bool] = None,
        num_beam_groups: Optional[int] = None,
@@ -692,8 +694,11 @@ class GenerationMixin:
            input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
                :obj:`torch.LongTensor` of shape :obj:`(1,)`.
            max_length (:obj:`int`, `optional`, defaults to :obj:`model.config.max_length`):
                The maximum length of the sequence to be generated.
            max_new_tokens (:obj:`int`, `optional`, defaults to :obj:`None`):
                The maximum number of tokens to generate, ignoring the number of tokens in the prompt. Use either
                :obj:`max_new_tokens` or :obj:`max_length`, but not both, as they serve the same purpose.
            min_length (:obj:`int`, `optional`, defaults to 10):
                The minimum length of the sequence to be generated.
            do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
@@ -861,6 +866,15 @@ class GenerationMixin:
        """
        # set init values
        if max_length is None and max_new_tokens is None:
            # Both are None, default
            max_length = self.config.max_length
        elif max_length is not None and max_new_tokens is not None:
            # Both are set, this is odd, raise a warning
            warnings.warn(
                "Both `max_length` and `max_new_tokens` have been set but they serve the same purpose.", UserWarning
            )

        max_length = max_length if max_length is not None else self.config.max_length
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups

@@ -960,7 +974,10 @@ class GenerationMixin:
            remove_invalid_values=remove_invalid_values,
        )
        cur_len = input_ids.shape[-1]
        stopping_criteria = self._get_stopping_criteria(
            max_length=max_length, max_time=max_time, max_new_tokens=max_new_tokens, start_length=cur_len
        )

        if is_greedy_gen_mode:
            if num_return_sequences > 1:
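To make the length-handling rules above concrete, here is a hedged sketch of the three possible argument combinations; the `gpt2` checkpoint and the prompt are only assumptions, any causal LM behaves the same way:

```python
import warnings

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed checkpoint for illustration
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("Hello world", return_tensors="pt").input_ids

# 1. Neither argument given: max_length falls back to model.config.max_length.
model.generate(input_ids)

# 2. Only max_new_tokens: generation stops at most start_length + 5 tokens in.
outputs = model.generate(input_ids, max_new_tokens=5)
assert outputs.shape[-1] <= input_ids.shape[-1] + 5

# 3. Both given: they serve the same purpose, so a UserWarning is emitted.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model.generate(input_ids, max_length=20, max_new_tokens=5)
assert any(issubclass(w.category, UserWarning) for w in caught)
```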
@@ -12,6 +12,7 @@ if is_torch_available():
    from transformers.generation_stopping_criteria import (
        MaxLengthCriteria,
        MaxNewTokensCriteria,
        MaxTimeCriteria,
        StoppingCriteriaList,
        validate_stopping_criteria,

@@ -58,6 +59,21 @@ class StoppingCriteriaTestCase(unittest.TestCase):
        input_ids, scores = self._get_tensors(10)
        self.assertTrue(criteria(input_ids, scores))

    def test_max_new_tokens_criteria(self):
        criteria = MaxNewTokensCriteria(start_length=5, max_new_tokens=5)

        input_ids, scores = self._get_tensors(5)
        self.assertFalse(criteria(input_ids, scores))

        input_ids, scores = self._get_tensors(9)
        self.assertFalse(criteria(input_ids, scores))

        input_ids, scores = self._get_tensors(10)
        self.assertTrue(criteria(input_ids, scores))

        criteria_list = StoppingCriteriaList([criteria])
        self.assertEqual(criteria_list.max_length, 10)

    def test_max_time_criteria(self):
        input_ids, scores = self._get_tensors(5)
@@ -1615,3 +1615,26 @@ class GenerationIntegrationTests(unittest.TestCase):
        # BeamSearchScorer max_length should not influence "real" max_length
        self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())

    def test_max_new_tokens(self):
        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)

        self.assertEqual(list(input_ids.shape), [1, 15])

        # Encoder decoder call
        max_new_tokens = 3
        outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens)
        # 1 BOS + 3 new tokens
        self.assertEqual(list(outputs.shape), [1, 4])

        # Decoder only call
        outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
        # 15 + 3 new tokens
        self.assertEqual(list(outputs.shape), [1, 18])

        # max_new_tokens and max_length serve the same purpose and should not be used together.
        with self.assertWarns(UserWarning):
            outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)