Unverified Commit 80d712fa authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Adding new argument `max_new_tokens` for generate. (#11476)

* Adding new argument `max_new_tokens` for generate.

This is a proposal to add a new argument `max_new_tokens` to `generate`.
This include a `MaxNewTokensCriteria` that enables callers that don't
know about the token length ahead (like pipelines callers) to manage
more easily the length of their generated output.

* Adding a test for the user warning when both`max_length` and
`max_new_tokens` are used together.

* Removed redundant `no_grad`.
parent 2dd6fb25
......@@ -57,6 +57,29 @@ class MaxLengthCriteria(StoppingCriteria):
return input_ids.shape[-1] >= self.max_length
class MaxNewTokensCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever the generated number of tokens exceeds :obj:`max_new_tokens`.
Keep in mind for decoder-only type of transformers, this will **not** include the initial prompted tokens. This is
very close to :obj:`MaxLengthCriteria` but ignores the number of initial tokens.
Args:
start_length (:obj:`int`):
The number of initial tokens.
max_new_tokens (:obj:`int`):
The maximum number of tokens to generate.
"""
def __init__(self, start_length: int, max_new_tokens: int):
self.start_length = start_length
self.max_new_tokens = max_new_tokens
self.max_length = start_length + max_new_tokens
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return input_ids.shape[-1] >= self.max_length
class MaxTimeCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the
......@@ -89,6 +112,8 @@ class StoppingCriteriaList(list):
for stopping_criterium in self:
if isinstance(stopping_criterium, MaxLengthCriteria):
return stopping_criterium.max_length
elif isinstance(stopping_criterium, MaxNewTokensCriteria):
return stopping_criterium.max_length
return None
......
......@@ -42,6 +42,7 @@ from .generation_logits_process import (
)
from .generation_stopping_criteria import (
MaxLengthCriteria,
MaxNewTokensCriteria,
MaxTimeCriteria,
StoppingCriteriaList,
validate_stopping_criteria,
......@@ -628,15 +629,15 @@ class GenerationMixin:
return processors
def _get_stopping_criteria(
self,
max_length: Optional[int],
max_time: Optional[float],
self, max_length: Optional[int], max_time: Optional[float], max_new_tokens: Optional[int], start_length: int
) -> StoppingCriteriaList:
stopping_criteria = StoppingCriteriaList()
if max_length is not None:
stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
if max_time is not None:
stopping_criteria.append(MaxTimeCriteria(max_time=max_time))
if max_new_tokens is not None:
stopping_criteria.append(MaxNewTokensCriteria(start_length=start_length, max_new_tokens=max_new_tokens))
return stopping_criteria
@torch.no_grad()
......@@ -661,6 +662,7 @@ class GenerationMixin:
encoder_no_repeat_ngram_size: Optional[int] = None,
num_return_sequences: Optional[int] = None,
max_time: Optional[float] = None,
max_new_tokens: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
num_beam_groups: Optional[int] = None,
......@@ -692,8 +694,11 @@ class GenerationMixin:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
:obj:`torch.LongTensor` of shape :obj:`(1,)`.
max_length (:obj:`int`, `optional`, defaults to 20):
max_length (:obj:`int`, `optional`, defaults to :obj:`model.config.max_length`):
The maximum length of the sequence to be generated.
max_new_tokens (:obj:`int`, `optional`, defaults to None):
The maximum numbers of tokens to generate, ignore the current number of tokens. Use either
:obj:`max_new_tokens` or :obj:`max_length` but not both, they serve the same purpose.
min_length (:obj:`int`, `optional`, defaults to 10):
The minimum length of the sequence to be generated.
do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
......@@ -861,6 +866,15 @@ class GenerationMixin:
"""
# set init values
if max_length is None and max_new_tokens is None:
# Both are None, default
max_length = self.config.max_length
elif max_length is not None and max_new_tokens is not None:
# Both are set, this is odd, raise a warning
warnings.warn(
"Both `max_length` and `max_new_tokens` have been set but they serve the same purpose.", UserWarning
)
max_length = max_length if max_length is not None else self.config.max_length
num_beams = num_beams if num_beams is not None else self.config.num_beams
num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
......@@ -960,7 +974,10 @@ class GenerationMixin:
remove_invalid_values=remove_invalid_values,
)
stopping_criteria = self._get_stopping_criteria(max_length=max_length, max_time=max_time)
cur_len = input_ids.shape[-1]
stopping_criteria = self._get_stopping_criteria(
max_length=max_length, max_time=max_time, max_new_tokens=max_new_tokens, start_length=cur_len
)
if is_greedy_gen_mode:
if num_return_sequences > 1:
......
......@@ -12,6 +12,7 @@ if is_torch_available():
from transformers.generation_stopping_criteria import (
MaxLengthCriteria,
MaxNewTokensCriteria,
MaxTimeCriteria,
StoppingCriteriaList,
validate_stopping_criteria,
......@@ -58,6 +59,21 @@ class StoppingCriteriaTestCase(unittest.TestCase):
input_ids, scores = self._get_tensors(10)
self.assertTrue(criteria(input_ids, scores))
def test_max_new_tokens_criteria(self):
criteria = MaxNewTokensCriteria(start_length=5, max_new_tokens=5)
input_ids, scores = self._get_tensors(5)
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(9)
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(10)
self.assertTrue(criteria(input_ids, scores))
criteria_list = StoppingCriteriaList([criteria])
self.assertEqual(criteria_list.max_length, 10)
def test_max_time_criteria(self):
input_ids, scores = self._get_tensors(5)
......
......@@ -1615,3 +1615,26 @@ class GenerationIntegrationTests(unittest.TestCase):
# BeamSearchScorer max_length should not influence "real" max_length
self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())
def test_max_new_tokens(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 15])
# Encoder decoder call
max_new_tokens = 3
outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens)
# 1 BOS + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 4])
# Decoder only call
outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
# 15 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 18])
# max_new_tokens and max_length serve the same purpose and should not be used together.
with self.assertWarns(UserWarning):
outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment