Unverified Commit 2b5603f6 authored by Chan Woo Kim's avatar Chan Woo Kim Committed by GitHub
Browse files

Constrained Beam Search [without disjunctive decoding] (#15416)



* added classes to get started with constrained beam search

* in progress, think i can directly force tokens now but not yet with the round robin

* think now i have total control, now need to code the bank selection

* technically works as desired, need to optimize and fix design choices leading to undersirable outputs

* complete PR #1 without disjunctive decoding

* removed incorrect tests

* Delete k.txt

* Delete test.py

* Delete test.sh

* revert changes to test scripts

* genutils

* full implementation with testing, no disjunctive yet

* shifted docs

* passing all tests realistically ran locally

* removing accidentally included print statements

* fixed source of error in initial PR test

* fixing the get_device() vs device trap

* fixed documentation docstrings about constrained_beam_search

* fixed tests having failing for Speech2TextModel's floating point inputs

* fix cuda long tensor

* added examples and testing for them and founx & fixed a bug in beam_search and constrained_beam_search

* deleted accidentally added test halting code with assert False

* code reformat

* Update tests/test_generation_utils.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py

* fixing based on comments on PR

* took out the testing code that should but work fails without the beam search moditification ; style changes

* fixing comments issues

* docstrings for ConstraintListState

* typo in PhrsalConstraint docstring

* docstrings improvements
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 0113aae5
...@@ -16,8 +16,9 @@ This page lists all the utility functions used by [`~generation_utils.Generation ...@@ -16,8 +16,9 @@ This page lists all the utility functions used by [`~generation_utils.Generation
[`~generation_utils.GenerationMixin.greedy_search`], [`~generation_utils.GenerationMixin.greedy_search`],
[`~generation_utils.GenerationMixin.sample`], [`~generation_utils.GenerationMixin.sample`],
[`~generation_utils.GenerationMixin.beam_search`], [`~generation_utils.GenerationMixin.beam_search`],
[`~generation_utils.GenerationMixin.beam_sample`], and [`~generation_utils.GenerationMixin.beam_sample`],
[`~generation_utils.GenerationMixin.group_beam_search`]. [`~generation_utils.GenerationMixin.group_beam_search`], and
[`~generation_utils.GenerationMixin.constrained_beam_search`].
Most of those are only useful if you are studying the code of the generate methods in the library. Most of those are only useful if you are studying the code of the generate methods in the library.
...@@ -190,6 +191,16 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than ...@@ -190,6 +191,16 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than
[[autodoc]] MaxTimeCriteria [[autodoc]] MaxTimeCriteria
- __call__ - __call__
## Constraints
A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output.
[[autodoc]] Constraint
[[autodoc]] PhrasalConstraint
[[autodoc]] ConstraintListState
## BeamSearch ## BeamSearch
[[autodoc]] BeamScorer [[autodoc]] BeamScorer
...@@ -200,6 +211,10 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than ...@@ -200,6 +211,10 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than
- process - process
- finalize - finalize
[[autodoc]] ConstrainedBeamSearchScorer
- process
- finalize
## Utilities ## Utilities
[[autodoc]] top_k_top_p_filtering [[autodoc]] top_k_top_p_filtering
......
...@@ -612,7 +612,12 @@ if is_torch_available(): ...@@ -612,7 +612,12 @@ if is_torch_available():
"TextDatasetForNextSentencePrediction", "TextDatasetForNextSentencePrediction",
] ]
_import_structure["deepspeed"] = [] _import_structure["deepspeed"] = []
_import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer"] _import_structure["generation_beam_constraints"] = [
"Constraint",
"ConstraintListState",
"PhrasalConstraint",
]
_import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"]
_import_structure["generation_logits_process"] = [ _import_structure["generation_logits_process"] = [
"ForcedBOSTokenLogitsProcessor", "ForcedBOSTokenLogitsProcessor",
"ForcedEOSTokenLogitsProcessor", "ForcedEOSTokenLogitsProcessor",
...@@ -2750,7 +2755,8 @@ if TYPE_CHECKING: ...@@ -2750,7 +2755,8 @@ if TYPE_CHECKING:
TextDataset, TextDataset,
TextDatasetForNextSentencePrediction, TextDatasetForNextSentencePrediction,
) )
from .generation_beam_search import BeamScorer, BeamSearchScorer from .generation_beam_constraints import Constraint, ConstraintListState, PhrasalConstraint
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .generation_logits_process import ( from .generation_logits_process import (
ForcedBOSTokenLogitsProcessor, ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor,
......
from abc import ABC, abstractmethod
from typing import List, Optional, Union
import torch
class Constraint(ABC):
r"""Abstract base class for all constraints that can be applied during generation.
It must define how the constraint can be satisfied.
All classes that inherit Constraint must follow the requirement that
```py
completed = False
while not completed:
_, completed = constraint.update(constraint.advance())
```
will always terminate (halt).
"""
def __init__(self):
# test for the above condition
self.test()
def test(self):
"""
Tests whether this constraint has been properly defined.
"""
counter = 0
completed = False
while not completed:
if counter == 1:
self.reset()
advance = self.advance()
if not self.does_advance(advance):
raise Exception(
"Custom Constraint is not defined correctly. self.does_advance(self.advance()) must be true."
)
stepped, completed, reset = self.update(advance)
counter += 1
if counter > 10000:
raise Exception("update() does not fulfill the constraint.")
if self.remaining() != 0:
raise Exception("Custom Constraint is not defined correctly.")
@abstractmethod
def advance(self):
"""
When called, returns the token that would take this constraint one step closer to being fulfilled.
Return:
token_ids(`torch.tensor`): Must be a tensor of a list of indexable tokens, not some integer.
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
@abstractmethod
def does_advance(self, token_id: int):
"""
Reads in a token and returns whether it creates progress.
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
@abstractmethod
def update(self, token_id: int):
"""
Reads in a token and returns booleans that indicate the progress made by it. This function will update the
state of this object unlikes `does_advance(self, token_id: int)`.
This isn't to test whether a certain token will advance the progress; it's to update its state as if it has
been generated. This becomes important if token_id != desired token (refer to else statement in
PhrasalConstraint)
Args:
token_id(`int`):
The id of a newly generated token in the beam search.
Return:
stepped(`bool`):
Whether this constraint has become one step closer to being fulfuilled.
completed(`bool`):
Whether this constraint has been completely fulfilled by this token being generated.
reset (`bool`):
Whether this constraint has reset its progress by this token being generated.
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
@abstractmethod
def reset(self):
"""
Resets the state of this constraint to its initialization. We would call this in cases where the fulfillment of
a constraint is abrupted by an unwanted token.
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
@abstractmethod
def remaining(self):
"""
Returns the number of remaining steps of `advance()` in order to complete this constraint.
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
@abstractmethod
def copy(self, stateful=False):
"""
Creates a new instance of this constraint.
Args:
stateful(`bool`): Whether to not only copy the constraint for new instance, but also its state.
Return:
constraint(`Constraint`): The same constraint as the one being called from.
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
class PhrasalConstraint(Constraint):
r"""
[`Constraint`] enforcing that an ordered sequence of tokens is included in the output.
Args:
token_ids (`List[int]`):
The id of the token that must be generated by the output.
"""
def __init__(self, token_ids: Union[List[int], torch.LongTensor]):
super(Constraint, self).__init__()
is_int_list = isinstance(token_ids, List) and isinstance(token_ids[0], int)
is_tensor = isinstance(token_ids, torch.Tensor)
is_int_tensor = (
is_tensor and token_ids.dtype in [torch.int16, torch.int32, torch.int64] and len(token_ids.size()) == 1
)
not_positive = torch.any(token_ids < 0) if is_tensor else len([t for t in token_ids if t < 0]) > 0
if isinstance(token_ids, int) or not (is_int_list or is_int_tensor) or not_positive:
raise ValueError(f"`token_ids` has to be a single list or tensor of positive integers but is {token_ids}")
if not is_tensor:
token_ids = torch.tensor(token_ids)
self.token_ids = token_ids
self.seqlen = self.token_ids.size(0)
self.fulfilled_idx = -1 # the index of the currently fulfilled step
self.completed = False
def advance(self):
return self.token_ids[self.fulfilled_idx + 1]
def does_advance(self, token_id: int):
if self.completed:
return False
# move to cpu to guarantee no device issues.
return token_id.cpu() == self.token_ids[self.fulfilled_idx + 1].cpu()
def update(self, token_id: int):
stepped = False
completed = False
reset = False
if self.does_advance(token_id):
self.fulfilled_idx += 1
stepped = True
if self.fulfilled_idx == (self.seqlen - 1):
completed = True
self.completed = completed
else:
# failed to make progress.
reset = True
self.reset()
return stepped, completed, reset
def reset(self):
self.completed = False
self.fulfilled_idx = 0
def remaining(self):
return self.seqlen - (self.fulfilled_idx + 1)
def copy(self, stateful=False):
new_constraint = PhrasalConstraint(self.token_ids)
if stateful:
new_constraint.seq_len = self.seqlen
new_constraint.fulfilled_idx = self.fulfilled_idx
new_constraint.completed = self.completed
return new_constraint
class ConstraintListState:
r"""
A class for beam scorers to track its progress through a list of constraints.
Args:
constraints (`List[Constraint]`):
A list of [`Constraint`] objects that must be fulfilled by the beam scorer.
"""
def __init__(self, constraints: List[Constraint]):
self.constraints = constraints
# max # of steps required to fulfill a given constraint
self.max_seqlen = max([c.seqlen for c in constraints if isinstance(c, PhrasalConstraint)])
self.n_constraints = len(constraints)
self.completed = False
self.init_state()
def init_state(self):
self.complete_constraints = []
self.inprogress_constraint = None
self.pending_constraints = [constraint.copy(stateful=False) for constraint in self.constraints]
def get_bank(self):
add = 0
if self.inprogress_constraint:
# extra points for having a constraint mid-fulfilled
add += self.max_seqlen - self.inprogress_constraint.remaining()
return (len(self.complete_constraints) * self.max_seqlen) + add
def advance(self):
"""The list of tokens to generate such that we can make progress.
By "list" we don't mean the list of token that will fully fulfill a constraint.
Given constraints `c_i = {t_ij | j == # of tokens}`, If we're not in the middle of progressing through a
specific constraint `c_i`, we return:
`[t_k1 for k in indices of unfulfilled constraints]`
If we are in the middle of a constraint, then we return:
`[t_ij]`, where `i` is the index of the inprogress constraint, `j` is the next step for the constraint.
Though we don't care which constraint is fulfilled first, if we are in the progress of fulfilling a constraint,
that's the only one we'll return.
"""
if self.inprogress_constraint is None:
token_list = []
for constraint in self.pending_constraints: # "pending" == "unfulfilled yet"
advance = constraint.advance()
token_list.append(advance)
else:
token_list = [self.inprogress_constraint.advance()]
if len(token_list) == 0:
return None
else:
return torch.stack(token_list)
def reset(self, token_ids: Optional[torch.LongTensor]):
"""
token_ids: the tokens generated thus far to reset the state of the progress through constraints.
"""
self.init_state()
if token_ids is not None and token_ids.size(0) > 0:
for token in token_ids:
# completes or steps **one** constraint
complete, stepped = self.add(token)
# the entire list of constraints are fulfilled
if self.completed:
break
return self
def add(self, token_id: Union[int, torch.LongTensor]):
complete, stepped = False, False
if self.completed:
complete = True
stepped = False
return complete, stepped
if self.inprogress_constraint is not None:
# In the middle of fulfilling a constraint. If the `token_id` *does* makes an incremental progress to current
# job, simply update the state
stepped, complete, reset = self.inprogress_constraint.update(token_id)
if reset:
# 1. If the next token breaks the progress, then we must restart.
# e.g. constraint = "I love pies" and sequence so far is "I love" but `token_id` == "books".
# But that doesn't mean we self.init_state(), since we only reset the state for this particular
# constraint, not the full list of constraints.
self.pending_constraints.append(self.inprogress_constraint.copy(stateful=False))
self.inprogress_constraint = None
if complete:
# 2. If the next token completes the constraint, move it to completed list, set
# inprogress to None. If there are no pending constraints either, then this full list of constraints
# is complete.
self.complete_constraints.append(self.inprogress_constraint)
self.inprogress_constraint = None
if len(self.pending_constraints) == 0:
# we're done!
self.completed = True
else:
# Not in the middle of fulfilling a constraint. So does this `token_id` helps us step towards any of our list
# of constraints?
for cidx, pending_constraint in enumerate(self.pending_constraints):
if pending_constraint.does_advance(token_id):
stepped, complete, reset = pending_constraint.update(token_id)
if not stepped:
raise Exception(
"constraint.update(token_id) is not yielding incremental progress, "
"even though constraint.does_advance(token_id) is true."
)
if complete:
self.complete_constraints.append(pending_constraint)
self.inprogress_constraint = None
if not complete and stepped:
self.inprogress_constraint = pending_constraint
if complete or stepped:
# If we made any progress at all, then it's at least not a "pending constraint".
self.pending_constraints = (
self.pending_constraints[:cidx] + self.pending_constraints[cidx + 1 :]
)
if len(self.pending_constraints) == 0 and self.inprogress_constraint is None:
# If there's no longer any pending after this and no inprogress either, then we must be
# complete.
self.completed = True
break # prevent accidentally stepping through multiple constraints with just one token.
return complete, stepped
def copy(self, stateful=True):
new_state = ConstraintListState(self.constraints) # we actually never though self.constraints objects
# throughout this process. So it's at initialization state.
if stateful:
new_state.complete_constraints = [
constraint.copy(stateful=True) for constraint in self.complete_constraints
]
if self.inprogress_constraint is not None:
new_state.inprogress_constraint = self.inprogress_constraint.copy(stateful=True)
new_state.pending_constraints = [constraint.copy() for constraint in self.pending_constraints]
return new_state
...@@ -16,11 +16,13 @@ ...@@ -16,11 +16,13 @@
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import UserDict from collections import UserDict
from typing import Optional, Tuple from typing import List, Optional, Tuple
import numpy as np
import torch import torch
from .file_utils import add_start_docstrings from .file_utils import add_start_docstrings
from .generation_beam_constraints import Constraint, ConstraintListState
PROCESS_INPUTS_DOCSTRING = r""" PROCESS_INPUTS_DOCSTRING = r"""
...@@ -336,12 +338,462 @@ class BeamSearchScorer(BeamScorer): ...@@ -336,12 +338,462 @@ class BeamSearchScorer(BeamScorer):
if sent_lengths.min().item() != sent_lengths.max().item(): if sent_lengths.min().item() != sent_lengths.max().item():
assert pad_token_id is not None, "`pad_token_id` has to be defined" assert pad_token_id is not None, "`pad_token_id` has to be defined"
decoded.fill_(pad_token_id) decoded.fill_(pad_token_id)
# fill with hypotheses and eos_token_id if the latter fits in # fill with hypotheses and eos_token_id if the latter fits in
for i, hypo in enumerate(best): for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < max_length: if sent_lengths[i] < max_length:
decoded[i, sent_lengths[i]] = eos_token_id decoded[i, sent_lengths[i]] = eos_token_id
return UserDict(
{
"sequences": decoded,
"sequence_scores": best_scores,
}
)
class ConstrainedBeamSearchScorer(BeamScorer):
r"""
[`BeamScorer`] implementing constrained beam search decoding.
Args:
batch_size (`int`):
Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
max_length (`int`):
The maximum length of the sequence to be generated.
num_beams (`int`):
Number of beams for beam search.
constraints (`List[Constraint]`):
A list of positive constraints represented as `Constraint` objects that must be fulfilled in the generation
output. For more information, the documentation of [`Constraint`] should be read.
device (`torch.device`):
Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
allocated.
length_penalty (`float`, *optional*, defaults to 1.0):
Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
sequences.
do_early_stopping (`bool`, *optional*, defaults to `False`):
Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
The number of beam hypotheses that shall be returned upon calling
[`~transformer.BeamSearchScorer.finalize`].
num_beam_groups (`int`):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
"""
def __init__(
self,
batch_size: int,
num_beams: int,
constraints: List[Constraint],
device: torch.device,
length_penalty: Optional[float] = 1.0,
do_early_stopping: Optional[bool] = False,
num_beam_hyps_to_keep: Optional[int] = 1,
num_beam_groups: Optional[int] = 1,
**kwargs,
):
self.num_beams = num_beams
self.device = device
self.length_penalty = length_penalty
self.do_early_stopping = do_early_stopping
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
self.num_beam_groups = num_beam_groups
self.group_size = self.num_beams // self.num_beam_groups
self.constraints = constraints
self._is_init = False
self._beam_hyps = [
BeamHypotheses(
num_beams=self.num_beams,
length_penalty=self.length_penalty,
early_stopping=self.do_early_stopping,
)
for _ in range(batch_size)
]
self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
if not isinstance(num_beams, int) or num_beams <= 1:
raise ValueError(
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead."
)
if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
raise ValueError(
f"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` "
f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
)
if "max_length" in kwargs:
warnings.warn(
"Passing `max_length` to ConstrainedBeamSearchScorer is deprecated and has no effect. "
"`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
", or `group_beam_search(...)`."
)
@property
def is_done(self) -> bool:
return self._done.all()
def make_constraint_states(self, n):
return [ConstraintListState([constraint.copy() for constraint in self.constraints]) for _ in range(n)]
def check_completes_constraints(self, sequence):
new_state = self.make_constraint_states(1)[0]
new_state = new_state.reset(sequence)
return new_state.completed
def process(
self,
input_ids: torch.LongTensor,
next_scores: torch.FloatTensor,
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
scores_for_all_vocab: torch.FloatTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
) -> Tuple[torch.Tensor]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
Current scores of the top `2 * num_beams` non-finished beam hypotheses.
next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
`input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
scores_for_all_vocab (`torch.FloatTensor` of shape `(batch_size * num_beams, sequence_length)`):
The scores of all tokens in the vocabulary for each of the beam hypotheses.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
Return:
`UserDict`: A dictionary composed of the fields as defined above:
- **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of
all
non-finished beams.
- **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be
added
to the non-finished beam_hypotheses.
- **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices
indicating to which beam the next tokens shall be added.
"""
cur_len = input_ids.shape[-1]
batch_size = len(self._beam_hyps)
if not (batch_size == (input_ids.shape[0] // self.group_size)):
if self.num_beam_groups > 1:
raise ValueError(
f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
f"size of {self.group_size} is expected by the beam scorer."
)
else:
raise ValueError(
f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
f"{self.group_size} is expected by the beam scorer."
)
device = input_ids.device
next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
if self.num_beams < len(beam_hyp):
raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
if eos_token_id is None or pad_token_id is None:
raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
# pad the batch
next_beam_scores[batch_idx, :] = 0
next_beam_tokens[batch_idx, :] = pad_token_id
next_beam_indices[batch_idx, :] = 0
continue
# next tokens for this sentence.
beam_idx = 0
for beam_token_rank, (next_token, next_score, next_index) in enumerate(
zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
):
batch_beam_idx = batch_idx * self.group_size + next_index
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (next_token.item() == eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
if is_beam_token_worse_than_top_num_beams:
continue
completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx])
if completes_constraint:
beam_hyp.add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
)
else:
# add next predicted token since it is not eos_token
next_beam_scores[batch_idx, beam_idx] = next_score
next_beam_tokens[batch_idx, beam_idx] = next_token
next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
beam_idx += 1
# once the beam for next step is full, don't add more tokens to it.
if beam_idx == self.group_size:
break
new_scores, new_tokens, new_indices = self.step_sentence_constraint(
batch_idx,
input_ids,
scores_for_all_vocab,
next_beam_scores[batch_idx],
next_beam_tokens[batch_idx],
next_beam_indices[batch_idx],
)
next_beam_scores[batch_idx] = new_scores
next_beam_tokens[batch_idx] = new_tokens
next_beam_indices[batch_idx] = new_indices
if beam_idx < self.group_size:
raise ValueError(
f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
)
# Check if we are done so that we can save a pad step if all(done)
self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
next_scores[batch_idx].max().item(), cur_len
)
return UserDict(
{
"next_beam_scores": next_beam_scores.view(-1),
"next_beam_tokens": next_beam_tokens.view(-1),
"next_beam_indices": next_beam_indices.view(-1),
}
)
def step_sentence_constraint(
self,
batch_idx: int,
input_ids: torch.LongTensor,
vocab_scores: torch.FloatTensor,
sent_beam_scores: torch.FloatTensor,
sent_beam_tokens: torch.LongTensor,
sent_beam_indices: torch.LongTensor,
push_progress: bool = False,
):
# sent_beam_tokens are the next {num_beams} number of tokens that are under consideration for this beam
# (candidate next tokens)
# 1. Adding "advance_tokens"
# using ConstraintStateList.advance(), we propose new tokens to be added into this "candidate list" that will
# advance us in fulfilling the constraints.
# 2. Selecting best candidates such that we end up with highest probable candidates
# that fulfill our constraints.
orig_len = sent_beam_indices.size(0)
device = sent_beam_indices.device
# initialize states
topk_contraint_states = self.make_constraint_states(orig_len)
advance_constraint_states = self.make_constraint_states(orig_len)
sidx, eidx = batch_idx * orig_len, (batch_idx + 1) * orig_len
this_batch_input_ids = input_ids[sidx:eidx]
this_batch_token_scores = vocab_scores[sidx:eidx]
full_hypotheses = torch.cat((input_ids[sent_beam_indices], sent_beam_tokens.unsqueeze(-1)), dim=-1)
# need to make new hypothesis that advance the constraints
track_new = {"new_seqs": [], "new_states": [], "new_indices": [], "new_tokens": [], "new_scores": []}
for seq_idx, pre_seq in enumerate(this_batch_input_ids):
# pre_seq = ith sequence generated before this step.
# input_ids -> (topk) generic beam search best model next tokens
# -> (advance) constraints forcing the next token
# either way, we need to sort them into "banks" later, so store a "ConstraintListState" for all types of
# hypotheses.
topk_state = topk_contraint_states[seq_idx]
topk_state.reset(full_hypotheses[seq_idx])
advance_state = advance_constraint_states[seq_idx]
advance_state.reset(pre_seq)
if not advance_state.completed:
advance_tokens = advance_state.advance()
for advance_token in advance_tokens.to(device):
# since adding each `advance_token` leads to a different hypothesis, create new state instance.
new_state = advance_state.copy(stateful=True)
new_state.add(advance_token)
advance_seq = torch.cat((pre_seq, advance_token.unsqueeze(0)), -1).cpu().tolist()
if advance_seq not in track_new["new_seqs"]:
# prevent duplicates, which are basically bound to happen in this process.
track_new["new_seqs"].append(advance_seq)
track_new["new_indices"].append(seq_idx)
track_new["new_tokens"].append(advance_token)
track_new["new_scores"].append(this_batch_token_scores[seq_idx].take(advance_token))
track_new["new_states"].append(new_state)
elif push_progress:
# Basically, `sent_beam_indices` often chooses very little among `input_ids` the generated sequences that
# actually fulfill our constraints. For example, let constraints == ["loves pies"] and
# pre_seq_1 = "The child loves pies and" pre_seq_2 = "The child plays in the playground and"
# Without this step, if `sent_beam_indices` is something like [1,1], then
# 1. `pre_seq_1` won't be added to the list of (topk) hypothesis since it's not in the indices and
# 2. it won't be added to the list of (advance) hypothesis since it's completed already. (this is
# the else part of `if constraints_completed[seq_idx]`)
# 3. it ends up simply getting removed from consideration.
# #3 might be fine and actually desired, since it's likely that it's a low-probability output anyways,
# especially if it's not in the list of `sent_beam_indices`. But this often leads to lengthened beam
# search times, since completed sequences keep getting removed after all this effort for constrained
# generation.
# Here, we basically take `pre_seq_1` and to "push" it into the considered list of hypotheses, by simply
# appending the next likely token in the vocabulary and adding it to the list of hypotheses.
new_score, new_token = torch.max(this_batch_token_scores[seq_idx], 0) # some next probable token
advance_seq = torch.cat((pre_seq, new_token.unsqueeze(0)), -1)
advance_state = advance_constraint_states[seq_idx]
advance_state.reset(advance_seq)
advance_seq = advance_seq.cpu().tolist()
if advance_seq not in track_new["new_seqs"]:
# but still don't want to have duplicates
track_new["new_seqs"].append(advance_seq)
track_new["new_indices"].append(seq_idx)
track_new["new_tokens"].append(new_token)
track_new["new_scores"].append(new_score)
track_new["new_states"].append(advance_state)
if len(track_new["new_indices"]) > 0:
new_indices = torch.tensor(track_new["new_indices"]).to(device)
new_tokens = torch.stack(track_new["new_tokens"]).to(device)
new_scores = torch.stack(track_new["new_scores"]).to(device)
all_states = topk_contraint_states + track_new["new_states"]
all_tokens = torch.cat((sent_beam_tokens, new_tokens), -1)
all_scores = torch.cat((sent_beam_scores, new_scores), -1)
all_banks = torch.tensor([one.get_bank() for one in all_states]).to(device)
zipped = all_banks * 100 + all_scores
indices = zipped.sort(descending=True).indices
sorted_banks = all_banks[indices]
# Then we end up with {sorted among bank C}, {sorted among bank C-1}, ..., {sorted among bank 0}
counter = -1
cur_bank = sorted_banks[0]
increments = []
for bank in sorted_banks:
if bank == cur_bank:
counter += 1
else:
counter = 0
cur_bank = bank
increments.append(counter)
rearrangers = torch.tensor(np.argsort(increments, kind="mergesort"))
indices = indices[rearrangers][:orig_len]
sent_beam_scores = all_scores[indices]
sent_beam_tokens = all_tokens[indices]
sent_beam_indices = torch.cat((sent_beam_indices, new_indices))[indices]
return sent_beam_scores, sent_beam_tokens, sent_beam_indices
def finalize(
self,
input_ids: torch.LongTensor,
final_beam_scores: torch.FloatTensor,
final_beam_tokens: torch.LongTensor,
final_beam_indices: torch.LongTensor,
max_length: int,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps)
# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
continue
# all open beam hypotheses are added to the beam hypothesis
# beam hypothesis class automatically keeps the best beams
ids_collect = []
for beam_id in range(self.num_beams):
batch_beam_idx = batch_idx * self.num_beams + beam_id
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
completes_constraint = self.check_completes_constraints(final_tokens)
if completes_constraint:
beam_hyp.add(final_tokens, final_score)
ids_collect.append(beam_id)
# due to overly complex constraints or other factors, sometimes we can't gaurantee a successful
# generation. In these cases we simply return the highest scoring outputs.
if len(ids_collect) < self.num_beam_hyps_to_keep:
for beam_id in range(self.num_beams):
if beam_id not in ids_collect:
batch_beam_idx = batch_idx * self.num_beams + beam_id
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
beam_hyp.add(final_tokens, final_score)
if len(ids_collect) >= self.num_beam_hyps_to_keep:
break
# select the best hypotheses
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
best = []
best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
# retrieve best hypotheses
for i, beam_hyp in enumerate(self._beam_hyps):
sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
for j in range(self.num_beam_hyps_to_keep):
best_hyp_tuple = sorted_hyps.pop()
best_score = best_hyp_tuple[0]
best_hyp = best_hyp_tuple[1]
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
# append to lists
best.append(best_hyp)
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
# prepare for adding eos
sent_lengths_max = sent_lengths.max().item() + 1
sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
# shorter batches are padded if needed
if sent_lengths.min().item() != sent_lengths.max().item():
assert pad_token_id is not None, "`pad_token_id` has to be defined"
decoded.fill_(pad_token_id)
# fill with hypotheses and eos_token_id if the latter fits in
for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < sent_max_len:
decoded[i, sent_lengths[i]] = eos_token_id
return UserDict( return UserDict(
{ {
"sequences": decoded, "sequences": decoded,
......
...@@ -24,7 +24,8 @@ import torch.distributed as dist ...@@ -24,7 +24,8 @@ import torch.distributed as dist
from torch import nn from torch import nn
from .file_utils import ModelOutput from .file_utils import ModelOutput
from .generation_beam_search import BeamScorer, BeamSearchScorer from .generation_beam_constraints import Constraint
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .generation_logits_process import ( from .generation_logits_process import (
EncoderNoRepeatNGramLogitsProcessor, EncoderNoRepeatNGramLogitsProcessor,
ForcedBOSTokenLogitsProcessor, ForcedBOSTokenLogitsProcessor,
...@@ -839,6 +840,7 @@ class GenerationMixin: ...@@ -839,6 +840,7 @@ class GenerationMixin:
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
constraints: Optional[List[Constraint]] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None, output_scores: Optional[bool] = None,
...@@ -860,7 +862,6 @@ class GenerationMixin: ...@@ -860,7 +862,6 @@ class GenerationMixin:
post](https://huggingface.co/blog/how-to-generate). post](https://huggingface.co/blog/how-to-generate).
Parameters: Parameters:
inputs (`torch.Tensor` of shape `(batch_size, sequence_length)`, `(batch_size, sequence_length, inputs (`torch.Tensor` of shape `(batch_size, sequence_length)`, `(batch_size, sequence_length,
feature_dim)` or `(batch_size, num_channels, height, width)`, *optional*): feature_dim)` or `(batch_size, num_channels, height, width)`, *optional*):
The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
...@@ -945,6 +946,9 @@ class GenerationMixin: ...@@ -945,6 +946,9 @@ class GenerationMixin:
Custom stopping criteria that complement the default stopping criteria built from arguments and a Custom stopping criteria that complement the default stopping criteria built from arguments and a
model's config. If a stopping criteria is passed that is already created with the arguments or a model's config. If a stopping criteria is passed that is already created with the arguments or a
model's config an error is thrown. This feature is intended for advanced users. model's config an error is thrown. This feature is intended for advanced users.
constraints (`List[Constraint]`, *optional*):
Custom constraints that can be added to the generation to ensure that the output will contain the use
of certain tokens as defined by `Constraint` objects, in the most sensible way possible.
output_attentions (`bool`, *optional*, defaults to `False`): output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details. returned tensors for more details.
...@@ -966,7 +970,6 @@ class GenerationMixin: ...@@ -966,7 +970,6 @@ class GenerationMixin:
crash. Note that using `remove_invalid_values` can slow down generation. crash. Note that using `remove_invalid_values` can slow down generation.
synced_gpus (`bool`, *optional*, defaults to `False`): synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3) Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
model_kwargs: model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
...@@ -1140,11 +1143,14 @@ class GenerationMixin: ...@@ -1140,11 +1143,14 @@ class GenerationMixin:
) )
# 6. determine generation mode # 6. determine generation mode
is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False is_constraint_gen_mode = constraints is not None
is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and constraints is None
is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and constraints is None
is_beam_sample_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is True is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and constraints is None
is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) is_beam_sample_gen_mode = (
(num_beams > 1) and (num_beam_groups == 1) and do_sample is True and constraints is None
)
is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and constraints is None
if num_beam_groups > num_beams: if num_beam_groups > num_beams:
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
...@@ -1339,6 +1345,50 @@ class GenerationMixin: ...@@ -1339,6 +1345,50 @@ class GenerationMixin:
**model_kwargs, **model_kwargs,
) )
elif is_constraint_gen_mode:
if num_return_sequences > num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
if num_beams <= 1:
raise ValueError("`num_beams` needs to be greater than 1 for constrained genertation.")
if do_sample:
raise ValueError("`do_sample` needs to be false for constrained generation.")
if num_beam_groups is not None and num_beam_groups > 1:
raise ValueError("`num_beam_groups` not supported yet for constrained generation.")
# 10. prepare beam search scorer
constrained_beam_scorer = ConstrainedBeamSearchScorer(
constraints=constraints,
batch_size=batch_size,
num_beams=num_beams,
device=self.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
num_beam_hyps_to_keep=num_return_sequences,
)
# 11. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
)
# 12. run beam search
return self.constrained_beam_search(
input_ids,
constrained_beam_scorer=constrained_beam_scorer,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
def greedy_search( def greedy_search(
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
...@@ -2800,6 +2850,318 @@ class GenerationMixin: ...@@ -2800,6 +2850,318 @@ class GenerationMixin:
else: else:
return sequence_outputs["sequences"] return sequence_outputs["sequences"]
def constrained_beam_search(
self,
input_ids: torch.LongTensor,
constrained_beam_scorer: ConstrainedBeamSearchScorer,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = None,
**model_kwargs,
) -> Union[BeamSearchOutput, torch.LongTensor]:
r"""
Generates sequences for models with a language modeling head using beam search decoding.
Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
constrained_beam_scorer (`ConstrainedBeamSearchScorer`):
A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
sorted during generation, while satisfying a list of positive constraints. For more information, the
documentation of [`ConstrainedBeamSearchScorer`] should be read.
logits_processor (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
stopping_criteria (`StoppingCriteriaList`, *optional*):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
logits_warper (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
to warp the prediction score distribution of the language modeling head applied before multinomial
sampling at each generation step.
max_length (`int`, *optional*, defaults to 20):
**DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
tokens. The maximum length of the sequence to be generated.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
output_hidden_states (`bool`, *optional*, defaults to `False`):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more details.
output_scores (`bool`, *optional*, defaults to `False`):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
Return:
[`generation_utilsBeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
[`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.
Examples:
```python
>>> from transformers import (
... AutoTokenizer,
... AutoModelForSeq2SeqLM,
... LogitsProcessorList,
... MinLengthLogitsProcessor,
... ConstrainedBeamSearchScorer,
... PhrasalConstraint,
... )
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
>>> encoder_input_str = "translate English to German: How old are you?"
>>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
>>> # lets run beam search using 3 beams
>>> num_beams = 3
>>> # define decoder start token ids
>>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
>>> input_ids = input_ids * model.config.decoder_start_token_id
>>> # add encoder_outputs to model keyword arguments
>>> model_kwargs = {
... "encoder_outputs": model.get_encoder()(
... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
... )
... }
>>> constraint_str = "sind"
>>> constraint_token_ids = tokenizer.encode(constraint_str)[:-1] # slice to remove eos token
>>> constraints = [PhrasalConstraint(token_ids=constraint_token_ids)]
>>> # instantiate beam scorer
>>> beam_scorer = ConstrainedBeamSearchScorer(
... batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
... )
>>> # instantiate logits processors
>>> logits_processor = LogitsProcessorList(
... [
... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
... ]
... )
>>> outputs = model.constrained_beam_search(
... input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
... )
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
# => ['Wie alter sind Sie?']
```"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if max_length is not None:
warnings.warn(
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
UserWarning,
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
if len(stopping_criteria) == 0:
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
batch_size = len(constrained_beam_scorer._beam_hyps)
num_beams = constrained_beam_scorer.num_beams
batch_beam_size, cur_len = input_ids.shape
if num_beams * batch_size != batch_beam_size:
raise ValueError(
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
)
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view((batch_size * num_beams,))
this_peer_finished = False # used by synced_gpus only
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
break
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
continue # don't waste resources running the code we don't need
next_token_logits = outputs.logits[:, -1, :]
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits = outputs.logits[:, -1, :]
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
scores_for_all_vocab = next_token_scores_processed.clone()
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
# reshape for beam search
vocab_size = next_token_scores.shape[-1]
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
next_token_scores, next_tokens = torch.topk(
next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
)
next_indices = (next_tokens / vocab_size).long()
next_tokens = next_tokens % vocab_size
# stateless
beam_outputs = constrained_beam_scorer.process(
input_ids,
next_token_scores,
next_tokens,
next_indices,
scores_for_all_vocab,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
)
beam_scores = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
if model_kwargs["past"] is not None:
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
# increase cur_len
cur_len = cur_len + 1
if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores):
if not synced_gpus:
break
else:
this_peer_finished = True
sequence_outputs = constrained_beam_scorer.finalize(
input_ids,
beam_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
)
if return_dict_in_generate:
if not output_scores:
sequence_outputs["sequence_scores"] = None
if self.config.is_encoder_decoder:
return BeamSearchEncoderDecoderOutput(
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
)
else:
return BeamSearchDecoderOnlyOutput(
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
return sequence_outputs["sequences"]
def top_k_top_p_filtering( def top_k_top_p_filtering(
logits: torch.FloatTensor, logits: torch.FloatTensor,
......
...@@ -80,6 +80,27 @@ class TextDatasetForNextSentencePrediction(metaclass=DummyObject): ...@@ -80,6 +80,27 @@ class TextDatasetForNextSentencePrediction(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class Constraint(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ConstraintListState(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PhrasalConstraint(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class BeamScorer(metaclass=DummyObject): class BeamScorer(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
...@@ -94,6 +115,13 @@ class BeamSearchScorer(metaclass=DummyObject): ...@@ -94,6 +115,13 @@ class BeamSearchScorer(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class ConstrainedBeamSearchScorer(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ForcedBOSTokenLogitsProcessor(metaclass=DummyObject): class ForcedBOSTokenLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
...@@ -25,7 +25,8 @@ from .test_modeling_common import floats_tensor, ids_tensor ...@@ -25,7 +25,8 @@ from .test_modeling_common import floats_tensor, ids_tensor
if is_torch_available(): if is_torch_available():
import torch import torch
from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer from transformers.generation_beam_constraints import PhrasalConstraint
from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer, ConstrainedBeamSearchScorer
class BeamSearchTester: class BeamSearchTester:
...@@ -232,6 +233,270 @@ class BeamSearchTester: ...@@ -232,6 +233,270 @@ class BeamSearchTester:
self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size]) self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])
class ConstrainedBeamSearchTester:
def __init__(
self,
parent,
constraints=None,
batch_size=3,
sequence_length=10,
vocab_size=99,
pad_token_id=0,
max_length=20,
num_beams=4,
length_penalty=2.0,
do_early_stopping=True,
num_beam_hyps_to_keep=2,
):
self.parent = parent
self.batch_size = batch_size
self.sequence_length = sequence_length
self.vocab_size = vocab_size
self.pad_token_id = pad_token_id
self.max_length = max_length
self.num_beams = num_beams
self.length_penalty = length_penalty
self.do_early_stopping = do_early_stopping
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
if constraints is None:
force_tokens = torch.randint(10, 50, (1, 2)).type(torch.LongTensor)[0]
constraints = [
PhrasalConstraint(force_tokens),
]
self.constraints = constraints
# cannot be randomely generated
self.eos_token_id = vocab_size + 1
def prepare_constrained_beam_scorer(self, **kwargs):
return ConstrainedBeamSearchScorer(
constraints=kwargs.get("constraints", self.constraints),
batch_size=kwargs.get("batch_size", self.batch_size),
num_beams=kwargs.get("num_beams", self.num_beams),
device=torch_device,
length_penalty=kwargs.get("length_penalty", self.length_penalty),
do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping),
num_beam_hyps_to_keep=kwargs.get("num_beam_hyps_to_keep", self.num_beam_hyps_to_keep),
)
def prepare_inputs(self):
input_ids = ids_tensor((self.batch_size * self.num_beams, self.sequence_length), self.vocab_size)
next_tokens = ids_tensor((self.batch_size, 2 * self.num_beams), self.vocab_size).to(torch_device)
next_indices = ids_tensor((self.batch_size, 2 * self.num_beams), self.num_beams).to(torch_device)
next_scores, _ = (-floats_tensor((self.batch_size, 2 * self.num_beams)).to(torch_device)).sort(descending=True)
scores_for_all_vocab, _ = (
-floats_tensor((self.batch_size * self.num_beams, self.vocab_size)).to(torch_device)
).sort(descending=True)
return (input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab)
def check_beam_hypotheses(self, input_ids, *args):
# check that correct number of beam hypotheses is set in beam scorer
constrained_beam_scorer = self.prepare_constrained_beam_scorer(do_early_stopping=True)
beam_hyp = constrained_beam_scorer._beam_hyps[0]
self.parent.assertEqual(len(constrained_beam_scorer._beam_hyps), self.batch_size)
# check correct type
self.parent.assertTrue(isinstance(beam_hyp, BeamHypotheses))
# check that num_beams is correctly set
self.parent.assertEqual(beam_hyp.num_beams, self.num_beams)
# check for early stopping deactivated
for beam_idx in range(self.num_beams):
beam_hyp.add(input_ids[beam_idx], -10.0)
# if early stopping True -> score does not matter
self.parent.assertTrue(beam_hyp.is_done(-10.0, 5))
# re-init
constrained_beam_scorer = self.prepare_constrained_beam_scorer(do_early_stopping=False)
beam_hyp = constrained_beam_scorer._beam_hyps[0]
# add `num_beams + 1` beams to change `worst_score`
for beam_idx in range(self.num_beams + 1):
beam_hyp.add(input_ids[beam_idx], -10.0 + float(beam_idx))
# -10.0 is removed => -9.0 is worst score
self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length ** beam_hyp.length_penalty))
# -5.0 is better than worst score => should not be finished
self.parent.assertFalse(beam_hyp.is_done(-5.0, self.sequence_length))
# -20.0 is worse than worst score => should be finished
self.parent.assertTrue(beam_hyp.is_done(-20.0, self.sequence_length))
def check_constrained_beam_scorer_update(
self, input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab
):
# check too many eos tokens
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
fulfilling_sequence = torch.stack([constraint.token_ids for constraint in self.constraints]).flatten()
fulfill_len = fulfilling_sequence.size(0)
input_ids[:, :fulfill_len] = fulfilling_sequence
tokens = next_tokens.clone()
tokens[0, :] = self.eos_token_id
with self.parent.assertRaises(ValueError):
constrained_beam_scorer.process(
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
)
# check all batches are done
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
tokens = next_tokens.clone()
tokens[:, : self.num_beams] = self.eos_token_id
constrained_beam_scorer.process(
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
)
# beam scorer should be done
self.parent.assertTrue(constrained_beam_scorer.is_done)
# check
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
tokens = next_tokens.clone()
tokens[:, 1] = self.eos_token_id
beam_outputs = constrained_beam_scorer.process(
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
output_indices = beam_outputs["next_beam_indices"]
def cut_expected_tensor(tensor):
return torch.cat([tensor[:, :1], tensor[:, 2 : self.num_beams + 1]], dim=1).flatten()
# check all outptus
# cut out id of eos token and take best `num_beams` outputs
expected_output_tokens = cut_expected_tensor(tokens)
expected_output_scores = cut_expected_tensor(next_scores)
# add num_beams * batch_idx
expected_output_indices = (
cut_expected_tensor(next_indices)
+ (torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams) * self.num_beams
)
self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))
# make sure ids of eos token are correctly saved in beam_hyps of beam scorer
for batch_idx in range(self.batch_size):
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
self.parent.assertListEqual(
input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
)
def check_constrained_beam_scorer_finalize(
self, input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab
):
# max_length should be only one more than current input_ids to check that eos is correctly appended
max_length = self.sequence_length + 1
# for testing finalize, we do want to have fulfilled constraints
fulfilling_sequence = torch.stack([constraint.token_ids for constraint in self.constraints]).flatten()
fulfill_len = fulfilling_sequence.size(0)
input_ids[:, :fulfill_len] = fulfilling_sequence
constrained_beam_scorer = self.prepare_constrained_beam_scorer(
num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False
)
constraints = constrained_beam_scorer.constraints
# update beams and append to input_ids
tokens = next_tokens.clone()
# first batch, first output has to finish with eos token id since scores are correctly sorted
tokens[0, 0] = self.eos_token_id
# make sure corresponding score is as good as possible to surely be picked first
next_scores[0, 0] = 0.0
beam_outputs = constrained_beam_scorer.process(
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
output_indices = beam_outputs["next_beam_indices"]
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
# finalize
sequence_output = constrained_beam_scorer.finalize(
input_ids,
output_scores,
output_tokens,
output_indices,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
)
sequences = sequence_output["sequences"]
sequence_scores = sequence_output["sequence_scores"]
# since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length`
self.parent.assertListEqual(list(sequences.shape), [self.batch_size, max_length])
self.parent.assertListEqual(list(sequence_scores.shape), [self.batch_size])
# check sequence_scores
self.parent.assertFalse((sequence_scores > 0).any().item())
# first batch has to finish with eos_token
self.parent.assertEqual(sequences[0, -1].item(), self.eos_token_id)
# other batches cannot finish with eos token
self.parent.assertNotEqual(sequences[1, -1].item(), self.eos_token_id)
self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id)
# test that the constraint is indeed fulfilled
for output in sequences:
for constraint in constraints:
forced_token_ids = constraint.token_ids
self.parent.assertEqual(self._check_sequence_inside_sequence(output, forced_token_ids), True)
# now test that if `num_beam_hyps_to_keep` is 3 => all beams are returned
# constrained_beam_scorer.num_beam_hyps_to_keep = self.num_beams
constrained_beam_scorer = self.prepare_constrained_beam_scorer(
num_beam_hyps_to_keep=self.num_beams, length_penalty=1.0, do_early_stopping=False
)
sequence_output = constrained_beam_scorer.finalize(
input_ids,
output_scores,
output_tokens,
output_indices,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
)
sequences = sequence_output["sequences"]
sequence_scores = sequence_output["sequence_scores"]
self.parent.assertListEqual(list(sequences.shape), [self.num_beams * self.batch_size, max_length])
self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
# set to same device. we don't care what device.
tensor_1, tensor_2 = tensor_1.cpu(), tensor_2.cpu()
in_order = tensor_1.size(0) <= tensor_2.size(0)
longer = tensor_2 if in_order else tensor_1
shorter = tensor_1 if in_order else tensor_2
flag = False
chunk_size = shorter.size(0)
for chunk_idx in range(longer.size(0) - chunk_size + 1):
subseq = longer[chunk_idx : chunk_idx + chunk_size]
if torch.equal(subseq, shorter):
flag = True
break
return flag
@require_torch @require_torch
class BeamSearchTest(unittest.TestCase): class BeamSearchTest(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -248,3 +513,21 @@ class BeamSearchTest(unittest.TestCase): ...@@ -248,3 +513,21 @@ class BeamSearchTest(unittest.TestCase):
def test_beam_scorer_finalize(self): def test_beam_scorer_finalize(self):
inputs = self.beam_search_tester.prepare_inputs() inputs = self.beam_search_tester.prepare_inputs()
self.beam_search_tester.check_beam_scores_finalize(*inputs) self.beam_search_tester.check_beam_scores_finalize(*inputs)
@require_torch
class ConstrainedBeamSearchTest(unittest.TestCase):
def setUp(self):
self.constrained_beam_search_tester = ConstrainedBeamSearchTester(self)
def test_constrained_beam_hypotheses(self):
inputs = self.constrained_beam_search_tester.prepare_inputs()
self.constrained_beam_search_tester.check_beam_hypotheses(*inputs)
def test_constrained_beam_scorer_update(self):
inputs = self.constrained_beam_search_tester.prepare_inputs()
self.constrained_beam_search_tester.check_constrained_beam_scorer_update(*inputs)
def test_constrained_beam_scorer_finalize(self):
inputs = self.constrained_beam_search_tester.prepare_inputs()
self.constrained_beam_search_tester.check_constrained_beam_scorer_finalize(*inputs)
...@@ -27,6 +27,8 @@ if is_torch_available(): ...@@ -27,6 +27,8 @@ if is_torch_available():
import torch import torch
from transformers import ( from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
BartForConditionalGeneration, BartForConditionalGeneration,
BartTokenizer, BartTokenizer,
GPT2LMHeadModel, GPT2LMHeadModel,
...@@ -37,7 +39,8 @@ if is_torch_available(): ...@@ -37,7 +39,8 @@ if is_torch_available():
VisionEncoderDecoderModel, VisionEncoderDecoderModel,
top_k_top_p_filtering, top_k_top_p_filtering,
) )
from transformers.generation_beam_search import BeamSearchScorer from transformers.generation_beam_constraints import PhrasalConstraint
from transformers.generation_beam_search import BeamSearchScorer, ConstrainedBeamSearchScorer
from transformers.generation_logits_process import ( from transformers.generation_logits_process import (
ForcedBOSTokenLogitsProcessor, ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor,
...@@ -190,6 +193,25 @@ class GenerationTesterMixin: ...@@ -190,6 +193,25 @@ class GenerationTesterMixin:
) )
return beam_kwargs, beam_scorer return beam_kwargs, beam_scorer
@staticmethod
def _get_constrained_beam_scorer_and_kwargs(batch_size, max_length, constraints, num_return_sequences=1):
beam_kwargs = {
"early_stopping": False,
"length_penalty": 2.0,
"num_beams": num_return_sequences * 4,
"num_return_sequences": num_return_sequences,
}
beam_scorer = ConstrainedBeamSearchScorer(
batch_size=batch_size,
constraints=constraints,
num_beams=beam_kwargs["num_beams"],
device=torch_device,
length_penalty=beam_kwargs["length_penalty"],
do_early_stopping=beam_kwargs["early_stopping"],
num_beam_hyps_to_keep=num_return_sequences,
)
return beam_kwargs, beam_scorer
@staticmethod @staticmethod
def _get_encoder_outputs( def _get_encoder_outputs(
model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1 model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
...@@ -526,6 +548,69 @@ class GenerationTesterMixin: ...@@ -526,6 +548,69 @@ class GenerationTesterMixin:
) )
return output_generate, output_group_beam_search return output_generate, output_group_beam_search
def _constrained_beam_search_generate(
self,
model,
input_ids,
attention_mask,
max_length,
constrained_beam_scorer,
constraints,
beam_kwargs,
logits_processor,
logits_process_kwargs,
output_scores=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
):
output_generate = model.generate(
input_ids,
attention_mask=attention_mask,
do_sample=False,
max_length=max_length,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
constraints=constraints,
**beam_kwargs,
**logits_process_kwargs,
)
# group_beam_search does not automatically interleave `batch_size` dim for `num_beams`
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
model,
input_ids,
attention_mask,
num_interleave=constrained_beam_scorer.num_beams,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
kwargs["encoder_outputs"] = encoder_outputs
input_ids_clone = input_ids_clone.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
else:
attention_mask_clone = attention_mask.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
input_ids_clone = input_ids.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
with torch.no_grad():
output_group_beam_search = model.constrained_beam_search(
input_ids_clone,
constrained_beam_scorer,
max_length=max_length,
attention_mask=attention_mask_clone,
logits_processor=logits_processor,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
**kwargs,
)
return output_generate, output_group_beam_search
def test_greedy_generate(self): def test_greedy_generate(self):
# check `generate()` and `greedy_search()` are equal # check `generate()` and `greedy_search()` are equal
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
...@@ -719,6 +804,7 @@ class GenerationTesterMixin: ...@@ -719,6 +804,7 @@ class GenerationTesterMixin:
logits_process_kwargs=logits_process_kwargs, logits_process_kwargs=logits_process_kwargs,
logits_processor=logits_processor, logits_processor=logits_processor,
) )
self.assertListEqual(output_generate.tolist(), output_beam_search.tolist()) self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())
# check `generate()` and `beam_search()` are equal for `num_return_sequences` # check `generate()` and `beam_search()` are equal for `num_return_sequences`
...@@ -1085,6 +1171,164 @@ class GenerationTesterMixin: ...@@ -1085,6 +1171,164 @@ class GenerationTesterMixin:
output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
) )
def test_constrained_beam_search_generate(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
# It is important set set the eos_token_id to None to ensure that no sequences
# shorter than `max_length` can be generated which could lead to flaky circle ci
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
config.eos_token_id = None
config.forced_eos_token_id = None
model = model_class(config).to(torch_device).eval()
max_length = 20
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1],
config.eos_token_id,
config.forced_bos_token_id,
config.forced_eos_token_id,
max_length,
)
# check `generate()` and `constrained_beam_search()` are equal
# Sample constraints
if not input_ids.dtype == torch.float32:
min_id = torch.min(input_ids) + 3
max_id = torch.max(input_ids)
else:
# otherwise this throws an error for Speech2TextModel since its inputs are floating points
min_id = 3
max_id = 100
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
constraints = [
PhrasalConstraint(force_tokens),
]
beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
input_ids.shape[0], max_length, constraints, num_return_sequences=1
)
output_generate, output_beam_search = self._constrained_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
max_length=max_length,
constrained_beam_scorer=beam_scorer,
constraints=constraints,
beam_kwargs=beam_kwargs,
logits_processor=logits_processor,
logits_process_kwargs=logits_process_kwargs,
)
self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())
for generation_output in output_generate:
self._check_sequence_inside_sequence(force_tokens, generation_output)
# check `generate()` and `constrained_beam_search()` are equal for `num_return_sequences`
# Sample constraints
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
constraints = [
PhrasalConstraint(force_tokens),
]
num_return_sequences = 2
max_length = 20
beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
input_ids.shape[0], max_length, constraints, num_return_sequences=num_return_sequences
)
output_generate, output_beam_search = self._constrained_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
max_length=max_length,
constrained_beam_scorer=beam_scorer,
constraints=constraints,
beam_kwargs=beam_kwargs,
logits_processor=logits_processor,
logits_process_kwargs=logits_process_kwargs,
)
self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())
for generation_output in output_generate:
self._check_sequence_inside_sequence(force_tokens, generation_output)
def test_constrained_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
# disable cache
config.use_cache = False
# It is important set set the eos_token_id to None to ensure that no sequences
# shorter than `max_length` can be generated which could lead to flaky circle ci
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
config.eos_token_id = None
config.forced_eos_token_id = None
model = model_class(config).to(torch_device).eval()
if model.config.is_encoder_decoder:
max_length = 20
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1],
config.eos_token_id,
config.forced_bos_token_id,
config.forced_eos_token_id,
max_length,
)
# Sample constraints
if not input_ids.dtype == torch.float32:
min_id = torch.min(input_ids) + 3
max_id = torch.max(input_ids)
else:
# otherwise this throws an error for Speech2TextModel since its inputs are floating points
min_id = 3
max_id = 100
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
constraints = [
PhrasalConstraint(force_tokens),
]
beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
input_ids.shape[0], max_length, constraints, num_return_sequences=1
)
output_generate, output_beam_search = self._constrained_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
max_length=max_length,
constrained_beam_scorer=beam_scorer,
constraints=constraints,
beam_kwargs=beam_kwargs,
logits_processor=logits_processor,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self.assertListEqual(output_generate.sequences.tolist(), output_beam_search.sequences.tolist())
self.assertTrue(
torch.allclose(output_generate["sequences_scores"], output_beam_search["sequences_scores"], atol=1e-3)
)
self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],))
self.assertTrue((output_generate["sequences_scores"] < 0).all().item())
for output in (output_beam_search, output_generate):
self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams)
def test_generate_with_head_masking(self): def test_generate_with_head_masking(self):
"""Test designed for encoder-decoder models to ensure the attention head masking is used.""" """Test designed for encoder-decoder models to ensure the attention head masking is used."""
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"] attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
...@@ -1254,6 +1498,24 @@ class GenerationTesterMixin: ...@@ -1254,6 +1498,24 @@ class GenerationTesterMixin:
[encoder_expected_shape] * len(hidden_states), [encoder_expected_shape] * len(hidden_states),
) )
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
# set to same device. we don't care what device.
tensor_1, tensor_2 = tensor_1.cpu(), tensor_2.cpu()
in_order = tensor_1.size(0) <= tensor_2.size(0)
longer = tensor_2 if in_order else tensor_1
shorter = tensor_1 if in_order else tensor_2
flag = False
chunk_size = shorter.size(0)
for chunk_idx in range(longer.size(0) - chunk_size + 1):
subseq = longer[chunk_idx : chunk_idx + chunk_size]
if torch.equal(subseq, shorter):
flag = True
break
self.assertTrue(flag)
@require_torch @require_torch
class UtilsFunctionsTest(unittest.TestCase): class UtilsFunctionsTest(unittest.TestCase):
...@@ -2047,3 +2309,83 @@ class GenerationIntegrationTests(unittest.TestCase): ...@@ -2047,3 +2309,83 @@ class GenerationIntegrationTests(unittest.TestCase):
transition_scores_sum = transition_scores.sum(-1) transition_scores_sum = transition_scores.sum(-1)
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3)) self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
@slow
def test_constrained_beam_search(self):
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
force_tokens = tokenizer.encode(" scared", return_tensors="pt").to(torch_device)[0]
force_tokens_2 = tokenizer.encode(" big weapons", return_tensors="pt").to(torch_device)[0]
constraints = [
PhrasalConstraint(force_tokens),
PhrasalConstraint(force_tokens_2),
]
starting_text = ["The soldiers were not prepared and"]
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device)
outputs = model.generate(
input_ids,
constraints=constraints,
num_beams=10,
num_return_sequences=1,
no_repeat_ngram_size=1,
max_length=30,
remove_invalid_values=True,
)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"The soldiers were not prepared and didn't know how big the big weapons would be, so they scared them off. They had no idea what to do",
],
)
@slow
def test_constrained_beam_search_example_integration(self):
tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
encoder_input_str = "translate English to German: How old are you?"
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
# lets run beam search using 5 beams
num_beams = 5
# define decoder start token ids
input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
input_ids = input_ids * model.config.decoder_start_token_id
# add encoder_outputs to model keyword arguments
model_kwargs = {
"encoder_outputs": model.get_encoder()(
encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
)
}
constraint_str = "sind"
constraint_token_ids = tokenizer.encode(constraint_str)[:-1] # remove eos token
constraints = [PhrasalConstraint(token_ids=constraint_token_ids)]
# instantiate beam scorer
beam_scorer = ConstrainedBeamSearchScorer(
batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
)
# instantiate logits processors
logits_processor = LogitsProcessorList(
[
MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
]
)
outputs = model.constrained_beam_search(
input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(outputs, ["Wie alter sind Sie?"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment