Unverified commit f21af262, authored by Joao Gante and committed by GitHub

🚨🚨 Generate: standardize beam search behavior across frameworks (#21368)

parent ea55bd86
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import warnings
 from abc import ABC, abstractmethod
 from collections import UserDict
 from typing import List, Optional, Tuple, Union
@@ -130,8 +129,6 @@ class BeamSearchScorer(BeamScorer):
     Args:
         batch_size (`int`):
             Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
-        max_length (`int`):
-            The maximum length of the sequence to be generated.
         num_beams (`int`):
             Number of beams for beam search.
         device (`torch.device`):
@@ -142,14 +139,20 @@ class BeamSearchScorer(BeamScorer):
             the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
             likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
             `length_penalty` < 0.0 encourages shorter sequences.
-        do_early_stopping (`bool`, *optional*, defaults to `False`):
-            Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
+        do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
+            Controls the stopping condition for beam-based methods, like beam search. It accepts the following values:
+            `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where a
+            heuristic is applied and the generation stops when it is very unlikely to find better candidates;
+            `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
+            beam search algorithm).
         num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
             The number of beam hypotheses that shall be returned upon calling
             [`~transformer.BeamSearchScorer.finalize`].
         num_beam_groups (`int`):
             Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
             See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
+        max_length (`int`, *optional*):
+            The maximum length of the sequence to be generated.
     """

     def __init__(
@@ -158,10 +161,10 @@ class BeamSearchScorer(BeamScorer):
         num_beams: int,
         device: torch.device,
         length_penalty: Optional[float] = 1.0,
-        do_early_stopping: Optional[bool] = False,
+        do_early_stopping: Optional[Union[bool, str]] = False,
         num_beam_hyps_to_keep: Optional[int] = 1,
         num_beam_groups: Optional[int] = 1,
-        **kwargs,
+        max_length: Optional[int] = None,
     ):
         self.num_beams = num_beams
         self.device = device
@@ -177,6 +180,7 @@ class BeamSearchScorer(BeamScorer):
                 num_beams=self.num_beams,
                 length_penalty=self.length_penalty,
                 early_stopping=self.do_early_stopping,
+                max_length=max_length,
             )
             for _ in range(batch_size)
         ]
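Net effect of this hunk on callers: `max_length` is an explicit, optional constructor argument again instead of a silently ignored `**kwargs` entry, and `do_early_stopping` may now be a string. A minimal usage sketch (values are illustrative):

```python
import torch
from transformers import BeamSearchScorer

# Illustrative values; `max_length` must be set whenever
# `do_early_stopping` is the string "never" (see BeamHypotheses below).
scorer = BeamSearchScorer(
    batch_size=2,
    num_beams=4,
    device=torch.device("cpu"),
    do_early_stopping="never",  # now `bool` or the string "never"
    max_length=32,
)
```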
@@ -194,13 +198,6 @@ class BeamSearchScorer(BeamScorer):
                 f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
             )

-        if "max_length" in kwargs:
-            warnings.warn(
-                "Passing `max_length` to BeamSearchScorer is deprecated and has no effect. "
-                "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
-                ", or `group_beam_search(...)`."
-            )
-
     @property
     def is_done(self) -> bool:
         return self._done.all()
@@ -402,8 +399,6 @@ class ConstrainedBeamSearchScorer(BeamScorer):
     Args:
         batch_size (`int`):
             Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
-        max_length (`int`):
-            The maximum length of the sequence to be generated.
         num_beams (`int`):
             Number of beams for beam search.
         constraints (`List[Constraint]`):
@@ -417,14 +412,20 @@ class ConstrainedBeamSearchScorer(BeamScorer):
             the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
             likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
             `length_penalty` < 0.0 encourages shorter sequences.
-        do_early_stopping (`bool`, *optional*, defaults to `False`):
-            Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
+        do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
+            Controls the stopping condition for beam-based methods, like beam search. It accepts the following values:
+            `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where a
+            heuristic is applied and the generation stops when it is very unlikely to find better candidates;
+            `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
+            beam search algorithm).
         num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
             The number of beam hypotheses that shall be returned upon calling
             [`~transformer.BeamSearchScorer.finalize`].
         num_beam_groups (`int`):
             Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
             See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
+        max_length (`int`, *optional*):
+            The maximum length of the sequence to be generated.
     """

     def __init__(
@@ -434,10 +435,10 @@ class ConstrainedBeamSearchScorer(BeamScorer):
         constraints: List[Constraint],
         device: torch.device,
         length_penalty: Optional[float] = 1.0,
-        do_early_stopping: Optional[bool] = False,
+        do_early_stopping: Optional[Union[bool, str]] = False,
         num_beam_hyps_to_keep: Optional[int] = 1,
         num_beam_groups: Optional[int] = 1,
-        **kwargs,
+        max_length: Optional[int] = None,
     ):
         self.num_beams = num_beams
         self.device = device
@@ -454,6 +455,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
                 num_beams=self.num_beams,
                 length_penalty=self.length_penalty,
                 early_stopping=self.do_early_stopping,
+                max_length=max_length,
             )
             for _ in range(batch_size)
         ]
@@ -471,13 +473,6 @@ class ConstrainedBeamSearchScorer(BeamScorer):
                 f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
             )

-        if "max_length" in kwargs:
-            warnings.warn(
-                "Passing `max_length` to ConstrainedBeamSearchScorer is deprecated and has no effect. "
-                "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
-                ", or `group_beam_search(...)`."
-            )
-
     @property
     def is_done(self) -> bool:
         return self._done.all()
@@ -865,16 +860,23 @@ class ConstrainedBeamSearchScorer(BeamScorer):
 class BeamHypotheses:
-    def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool):
+    def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None):
         """
         Initialize n-best list of hypotheses.
         """
         self.length_penalty = length_penalty
         self.early_stopping = early_stopping
+        self.max_length = max_length
         self.num_beams = num_beams
         self.beams = []
         self.worst_score = 1e9

+        if not isinstance(self.early_stopping, bool) and self.max_length is None:
+            raise ValueError(
+                "When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the"
+                " BeamScorer class instance at initialization time."
+            )
+
     def __len__(self):
         """
         Number of hypotheses in the list.
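A sketch of the new guard, assuming direct access to the internal `transformers.generation.beam_search` module (`BeamHypotheses` is not part of the public API):

```python
from transformers.generation.beam_search import BeamHypotheses

# Fine: a string `early_stopping` together with a defined `max_length`.
BeamHypotheses(num_beams=4, length_penalty=1.0, early_stopping="never", max_length=32)

# Raises the ValueError added above: "never" needs `max_length` to bound
# the best attainable score.
try:
    BeamHypotheses(num_beams=4, length_penalty=1.0, early_stopping="never")
except ValueError as err:
    print(err)
```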
@@ -903,9 +905,26 @@ class BeamHypotheses:
         if len(self) < self.num_beams:
             return False
-        elif self.early_stopping:
+
+        # `True`: stop as soon as at least `num_beams` hypotheses are finished
+        if self.early_stopping is True:
             return True
+        # `False`: heuristic -- compute best possible score from `cur_len`, even though it is not entirely accurate
+        # when `length_penalty` is positive. See the discussion below for more details.
+        # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
+        elif self.early_stopping is False:
+            highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
+            ret = self.worst_score >= highest_attainable_score
+            return ret
+        # `"never"`: compute the best possible score, depending on the sign of `length_penalty`
         else:
-            cur_score = best_sum_logprobs / cur_len**self.length_penalty
-            ret = self.worst_score >= cur_score
+            # `length_penalty` > 0.0 -> max denominator is obtained from `max_length`, not from `cur_len` -> min
+            # abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
+            # its max this way
+            if self.length_penalty > 0.0:
+                highest_attainable_score = best_sum_logprobs / self.max_length**self.length_penalty
+            # the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
+            else:
+                highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
+            ret = self.worst_score >= highest_attainable_score
             return ret
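A self-contained sketch of the three branches, with the scorer state reduced to plain floats (numbers are made up for illustration):

```python
def is_done(worst_score, best_sum_logprobs, cur_len, max_length, length_penalty, early_stopping):
    # Mirrors the branching above, outside the class for illustration.
    if early_stopping is True:  # stop as soon as the beam is full
        return True
    if early_stopping is False:  # heuristic: best score assumed to come from cur_len
        return worst_score >= best_sum_logprobs / cur_len**length_penalty
    # "never": bound the attainable score using the sign of length_penalty
    length = max_length if length_penalty > 0.0 else cur_len
    return worst_score >= best_sum_logprobs / length**length_penalty

# With length_penalty > 0, "never" divides by max_length, yielding a less
# negative attainable score, so the search runs longer than the heuristic:
print(is_done(-1.0, -10.0, cur_len=8, max_length=32, length_penalty=1.0, early_stopping=False))    # True
print(is_done(-1.0, -10.0, cur_len=8, max_length=32, length_penalty=1.0, early_stopping="never"))  # False
```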
@@ -71,8 +71,12 @@ class GenerationConfig(PushToHubMixin):
             `min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.
         min_new_tokens (`int`, *optional*):
             The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.
-        early_stopping (`bool`, *optional*, defaults to `False`):
-            Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
+        early_stopping (`bool` or `str`, *optional*, defaults to `False`):
+            Controls the stopping condition for beam-based methods, like beam search. It accepts the following values:
+            `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where a
+            heuristic is applied and the generation stops when it is very unlikely to find better candidates;
+            `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
+            beam search algorithm).
         max_time(`float`, *optional*):
             The maximum amount of time you allow the computation to run for in seconds. generation will still finish
             the current pass after allocated time has been passed.
@@ -290,6 +294,9 @@ class GenerationConfig(PushToHubMixin):
                     logger.error(f"Can't set {key} with value {value} for {self}")
                     raise err

+        # Validate the values of the attributes
+        self.validate()
+
     def __eq__(self, other):
         self_dict = self.__dict__.copy()
         other_dict = other.__dict__.copy()
@@ -302,6 +309,14 @@ class GenerationConfig(PushToHubMixin):
     def __repr__(self):
         return f"{self.__class__.__name__} {self.to_json_string()}"

+    def validate(self):
+        """
+        Validates the values of the attributes of the GenerationConfig instance, and raises a `ValueError` if any of
+        the values are invalid.
+        """
+        if self.early_stopping not in {True, False, "never"}:
+            raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
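Assuming validation runs at construction time, as the hunk above suggests, an invalid `early_stopping` value is now rejected as soon as the config is built or updated; a quick sketch:

```python
from transformers import GenerationConfig

GenerationConfig(early_stopping="never")  # accepted: one of {True, False, "never"}

try:
    GenerationConfig(early_stopping="sometimes")  # not a valid value
except ValueError as err:
    print(err)  # `early_stopping` must be a boolean or 'never', but is sometimes.
```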
...
@@ -19,7 +19,7 @@ import copy
 import inspect
 import warnings
 from functools import partial
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union

 import numpy as np
@@ -275,6 +275,7 @@ class FlaxGenerationMixin:
         generation_config = copy.deepcopy(generation_config)
         model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
+        generation_config.validate()
         self._validate_model_kwargs(model_kwargs.copy())

         # set init values
@@ -633,7 +634,7 @@ class FlaxGenerationMixin:
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         length_penalty: Optional[float] = None,
-        early_stopping: Optional[bool] = None,
+        early_stopping: Optional[Union[bool, str]] = None,
         logits_processor: Optional[FlaxLogitsProcessorList] = None,
         trace: bool = True,
         params: Optional[Dict[str, jnp.ndarray]] = None,
@@ -733,14 +734,22 @@ class FlaxGenerationMixin:
             not_max_length_yet = state.cur_len < max_length

             # 2. can the new beams still improve?
-            best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
+            # early_stopping == False -> apply heuristic = always get the best score from `cur_len`. See the discussion
+            # below for more details.
+            # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
+            # early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of
+            # length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there.
+            if early_stopping == "never" and length_penalty > 0.0:
+                best_running_score = state.running_scores[:, :1] / (max_length**length_penalty)
+            else:
+                best_running_score = state.running_scores[:, :1] / (state.cur_len**length_penalty)
             worst_finished_score = jnp.where(
                 state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
             )
-            improvement_still_possible = jnp.all(worst_finished_score < best_running_score)
+            improvement_still_possible = jnp.any(best_running_score > worst_finished_score)

             # 3. is there still a beam that has not finished?
-            still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)
+            still_open_beam = ~(jnp.all(state.is_sent_finished) & (early_stopping is True))

             return not_max_length_yet & still_open_beam & improvement_still_possible
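Note the reduction change as well: `jnp.all(worst < best)` became `jnp.any(best > worst)`, so a single saturated batch entry no longer stops the whole loop. A small numpy sketch with illustrative scores:

```python
import numpy as np

best_running_score = np.array([[-0.5], [-3.0]])    # one entry per batch item
worst_finished_score = np.array([[-1.0], [-1.0]])

print(np.all(worst_finished_score < best_running_score))  # False: entry 1 is saturated
print(np.any(best_running_score > worst_finished_score))  # True: entry 0 can still improve
```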
@@ -813,7 +822,7 @@ class FlaxGenerationMixin:
             # 5. Get running sequences scores for next
             # Determine the top k beam indices (from top 2*k beams) from log probs
             # and gather top k beams (from top 2*k beams).
-            next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
+            next_topk_indices = lax.top_k(running_topk_log_probs, k=num_beams)[1]
             next_running_sequences, next_running_scores = gather_beams(
                 [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
             )
@@ -824,10 +833,9 @@ class FlaxGenerationMixin:
             # - make sure no scores can be added anymore if beam is full
             # - make sure still running sequences cannot be chosen as finalized beam
             topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
-            beams_in_batch_are_full = (
-                jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
-                & early_stopping
-            )
+            beams_in_batch_are_full = jnp.broadcast_to(
+                state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape
+            ) & (early_stopping is True)
             add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
             topk_log_probs += add_penalty * np.array(-1.0e7)
@@ -838,7 +846,7 @@ class FlaxGenerationMixin:
             merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
             merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
             merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
-            topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
+            topk_merged_indices = lax.top_k(merged_scores, k=num_beams)[1]
             next_sequences, next_scores, next_is_sent_finished = gather_beams(
                 [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
             )
@@ -877,7 +885,7 @@ class FlaxGenerationMixin:
         scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)

         # take best beam for each batch
-        sequences = sequences[:, -1]
-        scores = scores[:, -1]
+        sequences = sequences[:, 0]
+        scores = scores[:, 0]

         return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
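With the `jnp.flip` calls removed, the `lax.top_k` ordering (highest score first) is kept, so the best beam now sits at index 0 instead of -1. A numpy analogue of the new ordering, with made-up scores:

```python
import numpy as np

scores = np.array([0.1, 0.9, 0.4])
top2 = np.argsort(scores)[::-1][:2]  # descending, like lax.top_k indices
print(top2)             # [1 2]: best beam comes first
print(scores[top2[0]])  # 0.9 -> selecting [:, 0] picks the best beam
```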
@@ -611,6 +611,7 @@ class TFGenerationMixin:
         generation_config = copy.deepcopy(generation_config)
         model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
+        generation_config.validate()
         self._validate_model_kwargs(model_kwargs.copy())

         # 2. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models)
@@ -1808,7 +1809,7 @@ class TFGenerationMixin:
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         length_penalty: Optional[float] = None,
-        early_stopping: Optional[bool] = None,
+        early_stopping: Optional[Union[bool, str]] = None,
         logits_processor: Optional[TFLogitsProcessorList] = None,
         logits_warper: Optional[TFLogitsProcessorList] = None,
         num_return_sequences: Optional[int] = None,
@@ -1838,8 +1839,12 @@ class TFGenerationMixin:
                 to the sequence length, which in turn is used to divide the score of the sequence. Since the score is
                 the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences,
                 while `length_penalty` < 0.0 encourages shorter sequences.
-            early_stopping (`bool`, *optional*, defaults to `False`):
-                Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
+            early_stopping (`bool` or `str`, *optional*, defaults to `False`):
+                Controls the stopping condition for beam-based methods, like beam search. It accepts the following
+                values: `True`, where the generation stops as soon as there are `num_beams` complete candidates;
+                `False`, where a heuristic is applied and the generation stops when it is very unlikely to find better
+                candidates; `"never"`, where the beam search procedure only stops when there cannot be better
+                candidates (canonical beam search algorithm).
             logits_processor (`[TFLogitsProcessorList]`, *optional*):
                 An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`]
                 used to modify the prediction scores of the language modeling head applied at each generation step.
@@ -2009,16 +2014,24 @@ class TFGenerationMixin:
             not_max_length_yet = cur_len < max_length

             # 2. can the new beams still improve?
-            best_running_score = running_scores[:, :1] / (max_length**length_penalty)
+            # early_stopping == False -> apply heuristic = always get the best score from `cur_len`. See the discussion
+            # below for more details.
+            # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
+            # early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of
+            # length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there.
+            if early_stopping == "never" and length_penalty > 0.0:
+                best_running_score = running_scores[:, :1] / (max_length**length_penalty)
+            else:
+                best_running_score = running_scores[:, :1] / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
             worst_finished_score = tf.where(
                 is_sent_finished, tf.math.reduce_min(scores, axis=1, keepdims=True), -1.0e9
             )
-            improvement_still_possible = tf.math.reduce_all(worst_finished_score < best_running_score)
+            improvement_still_possible = tf.math.reduce_any(best_running_score > worst_finished_score)

             # 3. is there still a beam that has not finished?
-            still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & early_stopping)
+            still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & (early_stopping is True))

-            return not_max_length_yet & (still_open_beam | improvement_still_possible)
+            return not_max_length_yet & still_open_beam & improvement_still_possible

         def beam_search_body_fn(
             cur_len,
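The rewritten return also standardizes the combination of stopping conditions with the Flax branch: previously TF kept decoding when *either* a beam was still open *or* improvement was possible. A tiny sketch of the case where the old and new conditions disagree (illustrative constants):

```python
import tensorflow as tf

not_max_length_yet = tf.constant(True)
still_open_beam = tf.constant(True)
improvement_still_possible = tf.constant(False)

# old: kept looping although no remaining beam could improve the result
print(not_max_length_yet & (still_open_beam | improvement_still_possible))  # True
# new: all three conditions must hold, matching the Flax implementation
print(not_max_length_yet & still_open_beam & improvement_still_possible)    # False
```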
@@ -2140,12 +2153,9 @@ class TFGenerationMixin:
             # - make sure no scores can be added anymore if beam is full
             # - make sure still running sequences cannot be chosen as finalized beam
             topk_log_probs = topk_log_probs / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
-            beams_in_batch_are_full = (
-                tf.broadcast_to(
-                    tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), shape_list(did_topk_just_finished)
-                )
-                & early_stopping
-            )
+            beams_in_batch_are_full = tf.broadcast_to(
+                tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), shape_list(did_topk_just_finished)
+            ) & (early_stopping is True)
             add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
             topk_log_probs += tf.cast(add_penalty, tf.float32) * -1.0e9
@@ -2239,7 +2249,7 @@ class TFGenerationMixin:
         sequences = tf.where(none_finished[:, None, None], sequences, running_sequences)
         scores = tf.where(none_finished[:, None], scores, running_scores)

-        # Take best beams for each batch (the score is sorted in ascending order)
+        # Take best beams for each batch (the score is sorted in descending order)
         sequences = flatten_beam_dim(sequences[:, :num_return_sequences, :])
         scores = flatten_beam_dim(scores[:, :num_return_sequences])
...
@@ -1190,6 +1190,7 @@ class GenerationMixin:
         generation_config = copy.deepcopy(generation_config)
         model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
+        generation_config.validate()
         self._validate_model_kwargs(model_kwargs.copy())

         # 2. Set generation parameters if not already defined
@@ -1458,6 +1459,7 @@ class GenerationMixin:
                 length_penalty=generation_config.length_penalty,
                 do_early_stopping=generation_config.early_stopping,
                 num_beam_hyps_to_keep=generation_config.num_return_sequences,
+                max_length=generation_config.max_length,
             )
             # 12. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -1493,6 +1495,7 @@ class GenerationMixin:
                 device=inputs_tensor.device,
                 length_penalty=generation_config.length_penalty,
                 do_early_stopping=generation_config.early_stopping,
+                max_length=generation_config.max_length,
             )

             # 13. interleave input_ids with `num_beams` additional sequences per batch
@@ -1536,12 +1539,12 @@ class GenerationMixin:
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
                 num_beams=generation_config.num_beams,
-                max_length=stopping_criteria.max_length,
                 device=inputs_tensor.device,
                 length_penalty=generation_config.length_penalty,
                 do_early_stopping=generation_config.early_stopping,
                 num_beam_hyps_to_keep=generation_config.num_return_sequences,
                 num_beam_groups=generation_config.num_beam_groups,
+                max_length=generation_config.max_length,
             )
             # 12. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -1629,6 +1632,7 @@ class GenerationMixin:
                 length_penalty=generation_config.length_penalty,
                 do_early_stopping=generation_config.early_stopping,
                 num_beam_hyps_to_keep=generation_config.num_return_sequences,
+                max_length=generation_config.max_length,
             )
             # 12. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
...
@@ -1566,6 +1566,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
             length_penalty=generation_config.length_penalty,
             do_early_stopping=generation_config.early_stopping,
             num_beam_hyps_to_keep=generation_config.num_return_sequences,
+            max_length=generation_config.max_length,
         )
         return self.beam_search(
             input_ids,
...
@@ -2034,59 +2034,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
             **model_kwargs,
         )

-    def test_beam_search_warning_if_max_length_is_passed(self):
-        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
-        bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
-        bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
-            torch_device
-        )
-        batch_size = 1
-        num_beams = 3
-
-        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-        input_ids = input_ids.expand(num_beams, -1)
-        model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
-
-        # pretend decoder_input_ids correspond to first encoder input id
-        decoder_input_ids = input_ids[:, :1]
-
-        stopping_criteria_max_length = 18
-        stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])
-        with self.assertWarns(UserWarning):
-            beam_scorer = BeamSearchScorer(
-                batch_size=batch_size,
-                num_beams=num_beams,
-                device=torch_device,
-                max_length=10,
-            )
-
-        generated_ids = bart_model.beam_search(
-            decoder_input_ids,
-            num_beams=num_beams,
-            stopping_criteria=stopping_criteria,
-            beam_scorer=beam_scorer,
-            **model_kwargs,
-        )
-
-        beam_scorer_no_max_len = BeamSearchScorer(
-            batch_size=batch_size,
-            num_beams=num_beams,
-            device=torch_device,
-        )
-
-        generated_ids_no_max_len = bart_model.beam_search(
-            decoder_input_ids,
-            num_beams=num_beams,
-            stopping_criteria=stopping_criteria,
-            beam_scorer=beam_scorer_no_max_len,
-            **model_kwargs,
-        )
-
-        # BeamSearchScorer max_length should not influence "real" max_length
-        self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())
-
     def test_custom_stopping_criteria_overload_error(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
         bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
...
@@ -426,7 +426,7 @@ class FlaxBartModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
         )

         input_ids = tokenizer(input_str, return_tensors="np").input_ids

-        sequences = model.generate(input_ids, num_beams=2, max_length=20).sequences
+        sequences = model.generate(input_ids, num_beams=2, min_length=None, max_length=20).sequences
         output_str = tokenizer.batch_decode(sequences)[0]
...
@@ -224,7 +224,7 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
         output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)

         expected_string = [
-            "Hello this is a long string of words. I'm going to try to explain what I mean.",
+            "Hello this is a long string of words. I'm going to start with the first one.\n",
             "Hey, I'm not sure if I'm going to be able to do",
         ]
...
@@ -1076,7 +1076,7 @@ class FlaxT5ModelIntegrationTests(unittest.TestCase):
         expected_summaries = [
             'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
             " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
-            " magazine says . all 150 on board were killed when germanwings flight 9525 crashed .",
+            " magazine says . all 150 on board were killed in the crash .",
             "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
             " preliminary examination into the situation in the occupied Palestinian territory . as members of the"
             " court, Palestinians may be subject to counter-charges as well .",
...