Unverified commit aad95c7c authored by Nicolas Patry, committed by GitHub

Removed `max_length` from being mandatory within `generate`. (#11314)

* Removed `max_length` from being mandatory within `generate`.

- Moving on to fully using `StoppingCriteria` for `greedy` and `sample`
modes.
- `max_length` still used for `beam_search` and `group_beam_search`
(follow-up PR)
- Fixes a bug in `MaxLengthCriteria` (we should stop as soon as we
hit `max_length`; the comparison needs to be greater-or-equal, which affects
the tests).
- Added options to use `logits_processor` and `stopping_criteria`
directly within the `generate` function (so users can define their own
`logits_processor` and `stopping_criteria`).
- Modified the backward compat tests to make sure we issue a warning.

* Fix `max_length` argument in `generate`.

* Moving validate to being functional.

- Renamed `smax_length` to `stopping_max_length`.

* Removing `logits_processor` and `stopping_criteria` from `generate`
arguments.

* Deepcopy.

* Fix global variable name.
parent 95dab34d
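In practice, the change lets callers express the length cap as an explicit criterion instead of a `max_length` argument. A minimal sketch of the new calling convention (the checkpoint name and the direct `greedy_search` call are illustrative, not part of this diff):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation_stopping_criteria import MaxLengthCriteria, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog", return_tensors="pt").input_ids

# The cap is now an explicit stopping criterion rather than a `max_length` argument.
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
with torch.no_grad():
    output = model.greedy_search(input_ids, stopping_criteria=stopping_criteria)
print(tokenizer.decode(output[0]))
```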
--- a/src/transformers/generation_stopping_criteria.py
+++ b/src/transformers/generation_stopping_criteria.py
@@ -1,6 +1,7 @@
 import time
 import warnings
 from abc import ABC
+from copy import deepcopy
 from typing import Optional

 import torch
@@ -8,7 +9,7 @@ import torch
 from .file_utils import add_start_docstrings

-LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
+STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary.
@@ -33,7 +34,7 @@ LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
 class StoppingCriteria(ABC):
     """Abstract base class for all stopping criteria that can be applied during generation."""

-    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool:
         raise NotImplementedError("StoppingCriteria needs to be subclassed")

@@ -51,9 +52,9 @@ class MaxLengthCriteria(StoppingCriteria):
     def __init__(self, max_length: int):
         self.max_length = max_length

-    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        return input_ids.shape[-1] > self.max_length
+        return input_ids.shape[-1] >= self.max_length


 class MaxTimeCriteria(StoppingCriteria):
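The comparison fix above is the off-by-one described in the commit message: with `max_length=10`, a sequence that has reached length 10 must already stop, whereas the old `>` comparison would let an eleventh token be generated first. A quick check of the new semantics (tensor contents are irrelevant, only the sequence length matters):

```python
import torch

from transformers.generation_stopping_criteria import MaxLengthCriteria

criteria = MaxLengthCriteria(max_length=10)
scores = None  # MaxLengthCriteria only looks at the sequence length

assert not criteria(torch.zeros((1, 9), dtype=torch.long), scores)  # below the cap: keep going
assert criteria(torch.zeros((1, 10), dtype=torch.long), scores)     # at the cap: stop (old `>` said False)
```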
@@ -73,25 +74,29 @@ class MaxTimeCriteria(StoppingCriteria):
         self.max_time = max_time
         self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp

-    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
         return time.time() - self.initial_timestamp > self.max_time


 class StoppingCriteriaList(list):
-    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
         return any(criteria(input_ids, scores) for criteria in self)

+    @property
+    def max_length(self) -> Optional[int]:
+        for stopping_criterium in self:
+            if isinstance(stopping_criterium, MaxLengthCriteria):
+                return stopping_criterium.max_length
+        return None

-def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int):
-    found = False
-    for stopping_criterium in stopping_criteria:
-        if isinstance(stopping_criterium, MaxLengthCriteria):
-            found = True
-            if stopping_criterium.max_length != max_length:
-                warnings.warn(
-                    "You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning
-                )
-    if not found:
-        stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
+
+def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList:
+    stopping_max_length = stopping_criteria.max_length
+    new_stopping_criteria = deepcopy(stopping_criteria)
+    if stopping_max_length is not None and stopping_max_length != max_length:
+        warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
+    elif stopping_max_length is None:
+        new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
+    return new_stopping_criteria
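Per the new implementation above, `validate_stopping_criteria` is now functional: it returns a deep copy (extended with a `MaxLengthCriteria` when none is present) instead of mutating its argument, and `StoppingCriteriaList.max_length` is the property the beam-search paths later query. A small sketch of the resulting semantics:

```python
from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)

criteria = StoppingCriteriaList()            # no MaxLengthCriteria inside
checked = validate_stopping_criteria(criteria, 20)

assert criteria.max_length is None           # the original list is untouched
assert checked.max_length == 20              # the returned copy gained the criterion

# Conflicting lengths only warn; the criteria list keeps its own value.
mismatched = validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 20)
assert mismatched.max_length == 10
```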
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import warnings
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

@@ -564,6 +565,7 @@ class GenerationMixin:
         This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
         :obj:`~transformers.LogitsProcessor` instances used to modify the scores of the language model head.
         """
+        processors = LogitsProcessorList()
         # init warp parameters
         repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
@@ -589,7 +591,6 @@ class GenerationMixin:
             remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
         )
         # instantiate processors list
-        processors = LogitsProcessorList()

         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
         # all samplers can be found in `generation_utils_samplers.py`
@@ -629,7 +630,6 @@ class GenerationMixin:
         max_length: Optional[int],
         max_time: Optional[float],
     ) -> StoppingCriteriaList:
         stopping_criteria = StoppingCriteriaList()
         if max_length is not None:
             stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
@@ -859,9 +859,9 @@ class GenerationMixin:
         """
         # set init values
-        max_length = max_length if max_length is not None else self.config.max_length
         num_beams = num_beams if num_beams is not None else self.config.num_beams
         num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
+        max_length = max_length if max_length is not None else self.config.max_length
         do_sample = do_sample if do_sample is not None else self.config.do_sample
         num_return_sequences = (
             num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
@@ -958,10 +958,13 @@ class GenerationMixin:
             remove_invalid_values=remove_invalid_values,
         )

-        stopping_criteria = self._get_stopping_criteria(
-            max_length=max_length,
-            max_time=max_time,
-        )
+        stopping_criteria = self._get_stopping_criteria(max_length=max_length, max_time=max_time)
+        if max_length is not None:
+            warnings.warn(
+                "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
+                UserWarning,
+            )
+            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)

         if is_greedy_gen_mode:
             if num_return_sequences > 1:
@@ -974,7 +977,6 @@ class GenerationMixin:
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
-                max_length=max_length,
                 pad_token_id=pad_token_id,
                 eos_token_id=eos_token_id,
                 output_scores=output_scores,
@@ -1003,7 +1005,6 @@ class GenerationMixin:
                 logits_processor=logits_processor,
                 logits_warper=logits_warper,
                 stopping_criteria=stopping_criteria,
-                max_length=max_length,
                 pad_token_id=pad_token_id,
                 eos_token_id=eos_token_id,
                 output_scores=output_scores,
@@ -1021,9 +1022,12 @@ class GenerationMixin:
             if num_return_sequences > num_beams:
                 raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

+            if stopping_criteria.max_length is None:
+                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
+
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
-                max_length=max_length,
+                max_length=stopping_criteria.max_length,
                 num_beams=num_beams,
                 device=self.device,
                 length_penalty=length_penalty,
@@ -1039,7 +1043,6 @@ class GenerationMixin:
                 beam_scorer,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
-                max_length=max_length,
                 pad_token_id=pad_token_id,
                 eos_token_id=eos_token_id,
                 output_scores=output_scores,
@@ -1056,9 +1059,11 @@ class GenerationMixin:
             batch_size = input_ids.shape[0] * num_return_sequences

             length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
+            if stopping_criteria.max_length is None:
+                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
-                max_length=max_length,
+                max_length=stopping_criteria.max_length,
                 num_beams=num_beams,
                 device=self.device,
                 length_penalty=length_penalty,
@@ -1079,7 +1084,6 @@ class GenerationMixin:
                 logits_processor=logits_processor,
                 logits_warper=logits_warper,
                 stopping_criteria=stopping_criteria,
-                max_length=max_length,
                 pad_token_id=pad_token_id,
                 eos_token_id=eos_token_id,
                 output_scores=output_scores,
@@ -1100,10 +1104,13 @@ class GenerationMixin:
             if num_beams % num_beam_groups != 0:
                 raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")

+            if stopping_criteria.max_length is None:
+                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
+
             diverse_beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
-                max_length=max_length,
                 num_beams=num_beams,
+                max_length=stopping_criteria.max_length,
                 device=self.device,
                 length_penalty=length_penalty,
                 do_early_stopping=early_stopping,
@@ -1119,7 +1126,6 @@ class GenerationMixin:
                 diverse_beam_scorer,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
-                max_length=max_length,
                 pad_token_id=pad_token_id,
                 eos_token_id=eos_token_id,
                 output_scores=output_scores,
@@ -1160,7 +1166,8 @@ class GenerationMixin:
                 :class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
             max_length (:obj:`int`, `optional`, defaults to 20):
-                The maximum length of the sequence to be generated.
+                **DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
+                generated tokens. The maximum length of the sequence to be generated.
             pad_token_id (:obj:`int`, `optional`):
                 The id of the `padding` token.
             eos_token_id (:obj:`int`, `optional`):
@@ -1220,8 +1227,12 @@ class GenerationMixin:
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-        max_length = max_length if max_length is not None else self.config.max_length
-        validate_stopping_criteria(stopping_criteria, max_length)
+        if max_length is not None:
+            warnings.warn(
+                "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
+                UserWarning,
+            )
+            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
         output_scores = output_scores if output_scores is not None else self.config.output_scores
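The old-style direct call keeps working through the shim above: passing `max_length` now warns and is folded into the criteria list. A sketch of the backward-compatible path (reusing the `model` and `input_ids` from the earlier example):

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model.greedy_search(input_ids, max_length=20)  # legacy argument, still honored

assert any("deprecated" in str(w.message) for w in caught)
```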
@@ -1251,7 +1262,7 @@ class GenerationMixin:
         cur_len = input_ids.shape[-1]

         this_peer_finished = False  # used by synced_gpus only
-        while cur_len < max_length:
+        while True:
             if synced_gpus:
                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
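With `while True`, the greedy loop's only exits are the stopping criteria and the all-sequences-finished check, so any criterion can bound generation; for instance a wall-clock budget via `MaxTimeCriteria` (again reusing `model` and `input_ids`; the 2-second budget is illustrative):

```python
from transformers.generation_stopping_criteria import MaxTimeCriteria, StoppingCriteriaList

# Bound generation by time instead of length: stop after roughly two seconds.
criteria = StoppingCriteriaList([MaxTimeCriteria(max_time=2.0)])
output = model.greedy_search(input_ids, stopping_criteria=criteria)
```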
@@ -1384,7 +1395,8 @@ class GenerationMixin:
                 :class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
                 modeling head applied before multinomial sampling at each generation step.
             max_length (:obj:`int`, `optional`, defaults to 20):
-                The maximum length of the sequence to be generated.
+                **DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
+                generated tokens. The maximum length of the sequence to be generated.
             pad_token_id (:obj:`int`, `optional`):
                 The id of the `padding` token.
             eos_token_id (:obj:`int`, `optional`):
@@ -1452,8 +1464,12 @@ class GenerationMixin:
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-        max_length = max_length if max_length is not None else self.config.max_length
-        validate_stopping_criteria(stopping_criteria, max_length)
+        if max_length is not None:
+            warnings.warn(
+                "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
+                UserWarning,
+            )
+            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
@@ -1485,7 +1501,7 @@ class GenerationMixin:
         this_peer_finished = False  # used by synced_gpus only
         # auto-regressive generation
-        while cur_len < max_length:
+        while True:
             if synced_gpus:
                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -1620,7 +1636,8 @@ class GenerationMixin:
                 An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
                 :class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
             max_length (:obj:`int`, `optional`, defaults to 20):
-                The maximum length of the sequence to be generated.
+                **DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
+                generated tokens. The maximum length of the sequence to be generated.
             pad_token_id (:obj:`int`, `optional`):
                 The id of the `padding` token.
             eos_token_id (:obj:`int`, `optional`):
@@ -1700,8 +1717,14 @@ class GenerationMixin:
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-        max_length = max_length if max_length is not None else self.config.max_length
-        validate_stopping_criteria(stopping_criteria, max_length)
+        if max_length is not None:
+            warnings.warn(
+                "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
+                UserWarning,
+            )
+            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
+        if len(stopping_criteria) == 0:
+            warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
         output_scores = output_scores if output_scores is not None else self.config.output_scores
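Beam search, by contrast, still requires a length bound, because `BeamSearchScorer` is sized from `stopping_criteria.max_length` (hence the new `ValueError` above). A direct-call sketch under that constraint (scorer arguments and the per-beam expansion are illustrative; `model` and `input_ids` reused from the first example):

```python
from transformers import BeamSearchScorer
from transformers.generation_stopping_criteria import MaxLengthCriteria, StoppingCriteriaList

num_beams = 3
beam_scorer = BeamSearchScorer(
    batch_size=input_ids.shape[0],
    max_length=32,  # must match the MaxLengthCriteria below
    num_beams=num_beams,
    device=model.device,
)
# Each input row is duplicated once per beam before calling `beam_search`.
beam_input_ids = input_ids.repeat_interleave(num_beams, dim=0)
output = model.beam_search(
    beam_input_ids,
    beam_scorer,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=32)]),
)
```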
@@ -1740,7 +1763,7 @@ class GenerationMixin:
         beam_scores = beam_scores.view((batch_size * num_beams,))

         this_peer_finished = False  # used by synced_gpus only
-        while cur_len < max_length:
+        while True:
             if synced_gpus:
                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -1770,7 +1793,7 @@ class GenerationMixin:
             # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
             # cannot be generated both before and after the `F.log_softmax` operation.
             next_token_logits = self.adjust_logits_during_generation(
-                next_token_logits, cur_len=cur_len, max_length=max_length
+                next_token_logits, cur_len=cur_len, max_length=None
             )
             next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
@@ -1907,7 +1930,8 @@ class GenerationMixin:
                 :class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
                 modeling head applied before multinomial sampling at each generation step.
             max_length (:obj:`int`, `optional`, defaults to 20):
-                The maximum length of the sequence to be generated.
+                **DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
+                generated tokens. The maximum length of the sequence to be generated.
             pad_token_id (:obj:`int`, `optional`):
                 The id of the `padding` token.
             eos_token_id (:obj:`int`, `optional`):
@@ -1994,7 +2018,12 @@ class GenerationMixin:
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-        max_length = max_length if max_length is not None else self.config.max_length
+        if max_length is not None:
+            warnings.warn(
+                "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
+                UserWarning,
+            )
+            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
         output_scores = output_scores if output_scores is not None else self.config.output_scores
@@ -2028,7 +2057,7 @@ class GenerationMixin:
         beam_scores = beam_scores.view((batch_size * num_beams,))

         this_peer_finished = False  # used by synced_gpus only
-        while cur_len < max_length:
+        while True:
             if synced_gpus:
                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -2058,7 +2087,7 @@ class GenerationMixin:
             # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
             # cannot be generated both before and after the `F.log_softmax` operation.
             next_token_logits = self.adjust_logits_during_generation(
-                next_token_logits, cur_len=cur_len, max_length=max_length
+                next_token_logits, cur_len=cur_len, max_length=None
             )
             next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
@@ -2195,7 +2224,8 @@ class GenerationMixin:
                 An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
                 :class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
             max_length (:obj:`int`, `optional`, defaults to 20):
-                The maximum length of the sequence to be generated.
+                **DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
+                generated tokens. The maximum length of the sequence to be generated.
             pad_token_id (:obj:`int`, `optional`):
                 The id of the `padding` token.
             eos_token_id (:obj:`int`, `optional`):
@@ -2279,8 +2309,12 @@ class GenerationMixin:
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-        max_length = max_length if max_length is not None else self.config.max_length
-        validate_stopping_criteria(stopping_criteria, max_length)
+        if max_length is not None:
+            warnings.warn(
+                "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
+                UserWarning,
+            )
+            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
         output_scores = output_scores if output_scores is not None else self.config.output_scores
@@ -2324,7 +2358,7 @@ class GenerationMixin:
         beam_scores = beam_scores.view((batch_size * num_beams,))

         this_peer_finished = False  # used by synced_gpus only
-        while cur_len < max_length:
+        while True:
             if synced_gpus:
                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -2378,7 +2412,7 @@ class GenerationMixin:
             # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
             # cannot be generated both before and after the `F.log_softmax` operation.
             next_token_logits = self.adjust_logits_during_generation(
-                next_token_logits, cur_len=cur_len, max_length=max_length
+                next_token_logits, cur_len=cur_len, max_length=None
             )
             next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * group_size, vocab_size)
--- a/tests/test_generation_stopping_criteria.py
+++ b/tests/test_generation_stopping_criteria.py
@@ -40,10 +40,10 @@ class StoppingCriteriaTestCase(unittest.TestCase):
         self.assertFalse(criteria(input_ids, scores))

-        input_ids, scores = self._get_tensors(10)
+        input_ids, scores = self._get_tensors(9)
         self.assertFalse(criteria(input_ids, scores))

-        input_ids, scores = self._get_tensors(11)
+        input_ids, scores = self._get_tensors(10)
         self.assertTrue(criteria(input_ids, scores))

     def test_max_length_criteria(self):
@@ -52,10 +52,10 @@ class StoppingCriteriaTestCase(unittest.TestCase):
         input_ids, scores = self._get_tensors(5)
         self.assertFalse(criteria(input_ids, scores))

-        input_ids, scores = self._get_tensors(10)
+        input_ids, scores = self._get_tensors(9)
         self.assertFalse(criteria(input_ids, scores))

-        input_ids, scores = self._get_tensors(11)
+        input_ids, scores = self._get_tensors(10)
         self.assertTrue(criteria(input_ids, scores))

     def test_max_time_criteria(self):
@@ -73,7 +73,6 @@ class StoppingCriteriaTestCase(unittest.TestCase):
         with self.assertWarns(UserWarning):
             validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 11)

-        stopping_criteria = StoppingCriteriaList()
-        validate_stopping_criteria(stopping_criteria, 11)
+        stopping_criteria = validate_stopping_criteria(StoppingCriteriaList(), 11)

         self.assertEqual(len(stopping_criteria), 1)
--- a/tests/test_generation_utils.py
+++ b/tests/test_generation_utils.py
@@ -1358,13 +1358,14 @@ class GenerationIntegrationTests(unittest.TestCase):
             bos_token_id=bart_model.config.bos_token_id,
         )

-        bart_model.greedy_search(
-            input_ids,
-            max_length=max_length,
-            pad_token_id=bart_model.config.pad_token_id,
-            eos_token_id=bart_model.config.eos_token_id,
-            **model_kwargs,
-        )
+        with self.assertWarns(UserWarning):
+            bart_model.greedy_search(
+                input_ids,
+                max_length=max_length,
+                pad_token_id=bart_model.config.pad_token_id,
+                eos_token_id=bart_model.config.eos_token_id,
+                **model_kwargs,
+            )

     def test_max_length_backward_compat_sample(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
@@ -1381,13 +1382,14 @@ class GenerationIntegrationTests(unittest.TestCase):
             bos_token_id=bart_model.config.bos_token_id,
         )
         with torch.no_grad():
-            bart_model.sample(
-                input_ids,
-                max_length=max_length,
-                pad_token_id=bart_model.config.pad_token_id,
-                eos_token_id=bart_model.config.eos_token_id,
-                **model_kwargs,
-            )
+            with self.assertWarns(UserWarning):
+                bart_model.sample(
+                    input_ids,
+                    max_length=max_length,
+                    pad_token_id=bart_model.config.pad_token_id,
+                    eos_token_id=bart_model.config.eos_token_id,
+                    **model_kwargs,
+                )

     def test_max_length_backward_compat_beam_search(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
@@ -1413,9 +1415,10 @@ class GenerationIntegrationTests(unittest.TestCase):
             num_beams=num_beams,
             device=torch_device,
         )
-        _ = bart_model.beam_search(
-            input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs
-        )
+        with self.assertWarns(UserWarning):
+            _ = bart_model.beam_search(
+                input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs
+            )

     def test_max_length_backward_compat_group_beam_search(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
@@ -1445,9 +1448,10 @@ class GenerationIntegrationTests(unittest.TestCase):
             num_beam_hyps_to_keep=num_return_sequences,
             num_beam_groups=num_beam_groups,
         )
-        bart_model.group_beam_search(
-            input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs
-        )
+        with self.assertWarns(UserWarning):
+            bart_model.group_beam_search(
+                input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs
+            )

     def test_max_length_warning_if_different(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""