"tests/vscode:/vscode.git/clone" did not exist on "f26e4073707189c93915227779a4f6ea3c40d43b"
Unverified Commit aad95c7c authored by Nicolas Patry, committed by GitHub

Removed `max_length` from being mandatory within `generate`. (#11314)

* Removed `max_length` from being mandatory within `generate`.

- Moving on to fully using `StoppingCriteria` for `greedy` and `sample`
modes.
- `max_length` still used for `beam_search` and `group_beam_search`
(Follow up PR)
- Fixes a bug with `MaxLengthCriteria` (we should stop as soon as
we hit the max_length; the comparison needs to be greater-or-equal, which
affects the tests).
- Added options to use `logits_processor` and `stopping_criteria`
directly within `generate` function (so some users can define their own
`logits_processor` and `stopping_criteria`).
- Modified the backward compat tests to make sure we issue a warning.

* Fix `max_length` argument in `generate`.

* Moving validate to being functional.

- Renamed `smax_length` to `stopping_max_length`.

* Removing `logits_processor` and `stopping_criteria` from `generate`
arguments.

* Deepcopy.

* Fix global variable name.
parent 95dab34d
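
For context, a minimal sketch of the calling convention this change moves toward: cap generation with a `StoppingCriteriaList` instead of the deprecated `max_length` argument. The checkpoint name is illustrative, `generate` normally drives this loop for you, and import paths follow this revision's `generation_stopping_criteria` module.

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation_stopping_criteria import MaxLengthCriteria, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog", return_tensors="pt").input_ids

# Stop once the sequence reaches 20 tokens, replacing `max_length=20`.
criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
output_ids = model.greedy_search(
    input_ids, stopping_criteria=criteria, pad_token_id=tokenizer.eos_token_id
)
print(tokenizer.decode(output_ids[0]))
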
import time
import warnings
from abc import ABC
from copy import deepcopy
from typing import Optional
import torch
@@ -8,7 +9,7 @@ import torch
from .file_utils import add_start_docstrings
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
@@ -33,7 +34,7 @@ LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
class StoppingCriteria(ABC):
"""Abstract base class for all stopping criteria that can be applied during generation."""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
raise NotImplementedError("StoppingCriteria needs to be subclassed")
@@ -51,9 +52,9 @@ class MaxLengthCriteria(StoppingCriteria):
def __init__(self, max_length: int):
self.max_length = max_length
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return input_ids.shape[-1] > self.max_length
return input_ids.shape[-1] >= self.max_length
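
The switch from `>` to `>=` above fixes an off-by-one: a sequence already at `max_length` previously did not trigger the criterion, so one extra token was generated. A quick check of the new semantics (a sketch; tensor contents are dummies, only the shape matters):

import torch
from transformers.generation_stopping_criteria import MaxLengthCriteria

criteria = MaxLengthCriteria(max_length=10)
scores = torch.zeros((1, 100))  # dummy logits, unused by this criterion
assert not criteria(torch.zeros((1, 9), dtype=torch.long), scores)  # length 9: keep going
assert criteria(torch.zeros((1, 10), dtype=torch.long), scores)     # length 10: stop now
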
class MaxTimeCriteria(StoppingCriteria):
@@ -73,25 +74,29 @@ class MaxTimeCriteria(StoppingCriteria):
self.max_time = max_time
self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return time.time() - self.initial_timestamp > self.max_time
class StoppingCriteriaList(list):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return any(criteria(input_ids, scores) for criteria in self)
def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int):
found = False
for stopping_criterium in stopping_criteria:
@property
def max_length(self) -> Optional[int]:
for stopping_criterium in self:
if isinstance(stopping_criterium, MaxLengthCriteria):
found = True
if stopping_criterium.max_length != max_length:
warnings.warn(
"You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning
)
if not found:
stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
return stopping_criterium.max_length
return None
def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList:
stopping_max_length = stopping_criteria.max_length
new_stopping_criteria = deepcopy(stopping_criteria)
if stopping_max_length is not None and stopping_max_length != max_length:
warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
elif stopping_max_length is None:
new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
return new_stopping_criteria
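
To make the new contract concrete: `validate_stopping_criteria` no longer mutates its argument; it deep-copies the list, appends a `MaxLengthCriteria` when none is present, and only warns when the lengths disagree. A small sketch of these behaviors:

import warnings
from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)

# Empty list: the copy gains a MaxLengthCriteria, the caller's list is untouched.
criteria = StoppingCriteriaList()
new_criteria = validate_stopping_criteria(criteria, max_length=11)
assert len(criteria) == 0 and len(new_criteria) == 1

# Conflicting lengths: the user's criterion is kept, but a UserWarning is emitted.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 11)
assert any(issubclass(w.category, UserWarning) for w in caught)
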
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
@@ -564,6 +565,7 @@ class GenerationMixin:
This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
:obj:`~transformers.LogitsProcessor` instances used to modify the scores of the language model head.
"""
processors = LogitsProcessorList()
# init warp parameters
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
@@ -589,7 +591,6 @@
remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
)
# instantiate processors list
processors = LogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
@@ -629,7 +630,6 @@
max_length: Optional[int],
max_time: Optional[float],
) -> StoppingCriteriaList:
stopping_criteria = StoppingCriteriaList()
if max_length is not None:
stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
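
Since `StoppingCriteriaList.__call__` returns `any(...)`, the criteria that `_get_stopping_criteria` assembles compose as a logical OR: generation halts as soon as one of them fires. An equivalent hand-built list (values illustrative):

from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteriaList,
)

# Stop at 50 tokens OR after 2 seconds of wall time, whichever comes first.
criteria = StoppingCriteriaList([
    MaxLengthCriteria(max_length=50),
    MaxTimeCriteria(max_time=2.0),
])
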
@@ -859,9 +859,9 @@
"""
# set init values
max_length = max_length if max_length is not None else self.config.max_length
num_beams = num_beams if num_beams is not None else self.config.num_beams
num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
max_length = max_length if max_length is not None else self.config.max_length
do_sample = do_sample if do_sample is not None else self.config.do_sample
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
@@ -958,10 +958,13 @@
remove_invalid_values=remove_invalid_values,
)
stopping_criteria = self._get_stopping_criteria(
max_length=max_length,
max_time=max_time,
stopping_criteria = self._get_stopping_criteria(max_length=max_length, max_time=max_time)
if max_length is not None:
warnings.warn(
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
UserWarning,
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
if is_greedy_gen_mode:
if num_return_sequences > 1:
@@ -974,7 +977,6 @@
input_ids,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
@@ -1003,7 +1005,6 @@
logits_processor=logits_processor,
logits_warper=logits_warper,
stopping_criteria=stopping_criteria,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
@@ -1021,9 +1022,12 @@
if num_return_sequences > num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
max_length=max_length,
max_length=stopping_criteria.max_length,
num_beams=num_beams,
device=self.device,
length_penalty=length_penalty,
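
Because `BeamSearchScorer` still needs a concrete length, beam search now reads it off `stopping_criteria.max_length` and raises when it is missing. A sketch of driving `beam_search` directly under the new contract (checkpoint and sizes are illustrative; `generate` normally builds the scorer and expands the inputs for you):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BeamSearchScorer
from transformers.generation_stopping_criteria import MaxLengthCriteria, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
num_beams = 4

# `beam_search` expects input_ids already expanded to (batch_size * num_beams, seq_len).
input_ids = tokenizer("The quick brown", return_tensors="pt").input_ids
input_ids = input_ids.repeat_interleave(num_beams, dim=0)

criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
beam_scorer = BeamSearchScorer(
    batch_size=1,
    max_length=criteria.max_length,  # mirrors the change above
    num_beams=num_beams,
    device=model.device,
)
outputs = model.beam_search(
    input_ids, beam_scorer, stopping_criteria=criteria, pad_token_id=tokenizer.eos_token_id
)
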
@@ -1039,7 +1043,6 @@
beam_scorer,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
@@ -1056,9 +1059,11 @@
batch_size = input_ids.shape[0] * num_return_sequences
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
max_length=max_length,
max_length=stopping_criteria.max_length,
num_beams=num_beams,
device=self.device,
length_penalty=length_penalty,
@@ -1079,7 +1084,6 @@
logits_processor=logits_processor,
logits_warper=logits_warper,
stopping_criteria=stopping_criteria,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
@@ -1100,10 +1104,13 @@
if num_beams % num_beam_groups != 0:
raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
diverse_beam_scorer = BeamSearchScorer(
batch_size=batch_size,
max_length=max_length,
num_beams=num_beams,
max_length=stopping_criteria.max_length,
device=self.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
@@ -1119,7 +1126,6 @@
diverse_beam_scorer,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
@@ -1160,7 +1166,8 @@
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
max_length (:obj:`int`, `optional`, defaults to 20):
The maximum length of the sequence to be generated.
**DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
generated tokens. The maximum length of the sequence to be generated.
pad_token_id (:obj:`int`, `optional`):
The id of the `padding` token.
eos_token_id (:obj:`int`, `optional`):
@@ -1220,8 +1227,12 @@
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
max_length = max_length if max_length is not None else self.config.max_length
validate_stopping_criteria(stopping_criteria, max_length)
if max_length is not None:
warnings.warn(
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
UserWarning,
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
@@ -1251,7 +1262,7 @@
cur_len = input_ids.shape[-1]
this_peer_finished = False # used by synced_gpus only
while cur_len < max_length:
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
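
The old loop condition `while cur_len < max_length:` is gone; with `while True:` the loop exits only through the stopping criteria or once every sequence is finished. A toy stand-in for the loop body shows the shape (no model involved, just the stopping logic):

import torch
from transformers.generation_stopping_criteria import MaxLengthCriteria, StoppingCriteriaList

stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=8)])
input_ids = torch.zeros((1, 3), dtype=torch.long)
scores = torch.zeros((1, 100))

while True:
    # Stand-in for: forward pass, pick the next token, append it.
    input_ids = torch.cat([input_ids, torch.zeros((1, 1), dtype=torch.long)], dim=-1)
    if stopping_criteria(input_ids, scores):
        break

assert input_ids.shape[-1] == 8  # stops exactly at max_length with the `>=` fix
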
@@ -1384,7 +1395,8 @@
:class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
modeling head applied before multinomial sampling at each generation step.
max_length (:obj:`int`, `optional`, defaults to 20):
The maximum length of the sequence to be generated.
**DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
generated tokens. The maximum length of the sequence to be generated.
pad_token_id (:obj:`int`, `optional`):
The id of the `padding` token.
eos_token_id (:obj:`int`, `optional`):
@@ -1452,8 +1464,12 @@
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
max_length = max_length if max_length is not None else self.config.max_length
validate_stopping_criteria(stopping_criteria, max_length)
if max_length is not None:
warnings.warn(
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
UserWarning,
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
@@ -1485,7 +1501,7 @@
this_peer_finished = False # used by synced_gpus only
# auto-regressive generation
while cur_len < max_length:
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -1620,7 +1636,8 @@
An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
max_length (:obj:`int`, `optional`, defaults to 20):
The maximum length of the sequence to be generated.
**DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
generated tokens. The maximum length of the sequence to be generated.
pad_token_id (:obj:`int`, `optional`):
The id of the `padding` token.
eos_token_id (:obj:`int`, `optional`):
@@ -1700,8 +1717,14 @@
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
max_length = max_length if max_length is not None else self.config.max_length
validate_stopping_criteria(stopping_criteria, max_length)
if max_length is not None:
warnings.warn(
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
UserWarning,
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
if len(stopping_criteria) == 0:
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
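
Beam search can now legitimately run with no length cap at all, hence the loop-forever warning above. Anything implementing the `StoppingCriteria` interface can serve as the exit condition; a hypothetical custom criterion, as a sketch (the class name and token id are illustrative, not part of this PR):

import torch
from transformers.generation_stopping_criteria import StoppingCriteria, StoppingCriteriaList

class StopOnTokenCriteria(StoppingCriteria):
    # Hypothetical example: stop once any sequence ends with a given token id.
    def __init__(self, token_id: int):
        self.token_id = token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return bool((input_ids[:, -1] == self.token_id).any())

stopping_criteria = StoppingCriteriaList([StopOnTokenCriteria(token_id=50256)])
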
@@ -1740,7 +1763,7 @@
beam_scores = beam_scores.view((batch_size * num_beams,))
this_peer_finished = False # used by synced_gpus only
while cur_len < max_length:
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -1770,7 +1793,7 @@
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `F.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(
next_token_logits, cur_len=cur_len, max_length=max_length
next_token_logits, cur_len=cur_len, max_length=None
)
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
@@ -1907,7 +1930,8 @@
:class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
modeling head applied before multinomial sampling at each generation step.
max_length (:obj:`int`, `optional`, defaults to 20):
The maximum length of the sequence to be generated.
**DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
generated tokens. The maximum length of the sequence to be generated.
pad_token_id (:obj:`int`, `optional`):
The id of the `padding` token.
eos_token_id (:obj:`int`, `optional`):
@@ -1994,7 +2018,12 @@
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
max_length = max_length if max_length is not None else self.config.max_length
if max_length is not None:
warnings.warn(
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
UserWarning,
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
@@ -2028,7 +2057,7 @@
beam_scores = beam_scores.view((batch_size * num_beams,))
this_peer_finished = False # used by synced_gpus only
while cur_len < max_length:
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -2058,7 +2087,7 @@
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `F.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(
next_token_logits, cur_len=cur_len, max_length=max_length
next_token_logits, cur_len=cur_len, max_length=None
)
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
@@ -2195,7 +2224,8 @@
An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
max_length (:obj:`int`, `optional`, defaults to 20):
The maximum length of the sequence to be generated.
**DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
generated tokens. The maximum length of the sequence to be generated.
pad_token_id (:obj:`int`, `optional`):
The id of the `padding` token.
eos_token_id (:obj:`int`, `optional`):
@@ -2279,8 +2309,12 @@
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
max_length = max_length if max_length is not None else self.config.max_length
validate_stopping_criteria(stopping_criteria, max_length)
if max_length is not None:
warnings.warn(
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
UserWarning,
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
@@ -2324,7 +2358,7 @@
beam_scores = beam_scores.view((batch_size * num_beams,))
this_peer_finished = False # used by synced_gpus only
while cur_len < max_length:
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -2378,7 +2412,7 @@
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `F.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(
next_token_logits, cur_len=cur_len, max_length=max_length
next_token_logits, cur_len=cur_len, max_length=None
)
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * group_size, vocab_size)
...
@@ -40,10 +40,10 @@ class StoppingCriteriaTestCase(unittest.TestCase):
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(10)
input_ids, scores = self._get_tensors(9)
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(11)
input_ids, scores = self._get_tensors(10)
self.assertTrue(criteria(input_ids, scores))
def test_max_length_criteria(self):
@@ -52,10 +52,10 @@
input_ids, scores = self._get_tensors(5)
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(10)
input_ids, scores = self._get_tensors(9)
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(11)
input_ids, scores = self._get_tensors(10)
self.assertTrue(criteria(input_ids, scores))
def test_max_time_criteria(self):
@@ -73,7 +73,6 @@
with self.assertWarns(UserWarning):
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 11)
stopping_criteria = StoppingCriteriaList()
validate_stopping_criteria(stopping_criteria, 11)
stopping_criteria = validate_stopping_criteria(StoppingCriteriaList(), 11)
self.assertEqual(len(stopping_criteria), 1)
@@ -1358,6 +1358,7 @@ class GenerationIntegrationTests(unittest.TestCase):
bos_token_id=bart_model.config.bos_token_id,
)
with self.assertWarns(UserWarning):
bart_model.greedy_search(
input_ids,
max_length=max_length,
@@ -1381,6 +1382,7 @@
bos_token_id=bart_model.config.bos_token_id,
)
with torch.no_grad():
with self.assertWarns(UserWarning):
bart_model.sample(
input_ids,
max_length=max_length,
@@ -1413,6 +1415,7 @@
num_beams=num_beams,
device=torch_device,
)
with self.assertWarns(UserWarning):
_ = bart_model.beam_search(
input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs
)
@@ -1445,6 +1448,7 @@
num_beam_hyps_to_keep=num_return_sequences,
num_beam_groups=num_beam_groups,
)
with self.assertWarns(UserWarning):
bart_model.group_beam_search(
input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs
)
...