Unverified Commit 2b5603f6 authored by Chan Woo Kim, committed by GitHub

Constrained Beam Search [without disjunctive decoding] (#15416)



* added classes to get started with constrained beam search

* in progress, think i can directly force tokens now but not yet with the round robin

* think now i have total control, now need to code the bank selection

* technically works as desired, need to optimize and fix design choices leading to undesirable outputs

* complete PR #1 without disjunctive decoding

* removed incorrect tests

* Delete k.txt

* Delete test.py

* Delete test.sh

* revert changes to test scripts

* genutils

* full implementation with testing, no disjunctive yet

* shifted docs

* passing all tests that can realistically be run locally

* removing accidentally included print statements

* fixed source of error in initial PR test

* fixing the get_device() vs device trap

* fixed documentation docstrings about constrained_beam_search

* fixed tests failing for Speech2TextModel's floating point inputs

* fix cuda long tensor

* added examples and testing for them and found & fixed a bug in beam_search and constrained_beam_search

* deleted accidentally added test halting code with assert False

* code reformat

* Update tests/test_generation_utils.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py

* fixing based on comments on PR

* took out the testing code that should work but fails without the beam search modification; style changes

* fixing comments issues

* docstrings for ConstraintListState

* typo in PhrasalConstraint docstring

* docstrings improvements
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 0113aae5
@@ -16,8 +16,9 @@
This page lists all the utility functions used by [`~generation_utils.GenerationMixin.generate`],
[`~generation_utils.GenerationMixin.greedy_search`],
[`~generation_utils.GenerationMixin.sample`],
[`~generation_utils.GenerationMixin.beam_search`],
-[`~generation_utils.GenerationMixin.beam_sample`], and
-[`~generation_utils.GenerationMixin.group_beam_search`].
+[`~generation_utils.GenerationMixin.beam_sample`],
+[`~generation_utils.GenerationMixin.group_beam_search`], and
+[`~generation_utils.GenerationMixin.constrained_beam_search`].

Most of those are only useful if you are studying the code of the generate methods in the library.
@@ -190,6 +191,16 @@
[[autodoc]] MaxTimeCriteria
    - __call__

+## Constraints
+
+A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output.
+
+[[autodoc]] Constraint
+
+[[autodoc]] PhrasalConstraint
+
+[[autodoc]] ConstraintListState
+
## BeamSearch

[[autodoc]] BeamScorer
@@ -200,6 +211,10 @@
    - process
    - finalize

+[[autodoc]] ConstrainedBeamSearchScorer
+    - process
+    - finalize
+
## Utilities

[[autodoc]] top_k_top_p_filtering
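The new `Constraints` docs section corresponds to a user-facing feature: `generate()` now accepts a `constraints=` argument. A minimal usage sketch (it mirrors the slow integration test at the bottom of this diff; the model choice and phrasing are illustrative):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, PhrasalConstraint

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# force the phrase " big weapons" (as token ids) to appear somewhere in the output
force_tokens = tokenizer.encode(" big weapons", return_tensors="pt")[0]
constraints = [PhrasalConstraint(force_tokens)]

input_ids = tokenizer("The soldiers were not prepared and", return_tensors="pt").input_ids
outputs = model.generate(
    input_ids,
    constraints=constraints,  # new in this PR
    num_beams=10,  # constrained decoding runs on top of beam search
    no_repeat_ngram_size=1,
    max_length=30,
    remove_invalid_values=True,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```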
@@ -612,7 +612,12 @@ if is_torch_available():
        "TextDatasetForNextSentencePrediction",
    ]
    _import_structure["deepspeed"] = []
-    _import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer"]
+    _import_structure["generation_beam_constraints"] = [
+        "Constraint",
+        "ConstraintListState",
+        "PhrasalConstraint",
+    ]
+    _import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"]
    _import_structure["generation_logits_process"] = [
        "ForcedBOSTokenLogitsProcessor",
        "ForcedEOSTokenLogitsProcessor",
@@ -2750,7 +2755,8 @@ if TYPE_CHECKING:
        TextDataset,
        TextDatasetForNextSentencePrediction,
    )
-    from .generation_beam_search import BeamScorer, BeamSearchScorer
+    from .generation_beam_constraints import Constraint, ConstraintListState, PhrasalConstraint
+    from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
    from .generation_logits_process import (
        ForcedBOSTokenLogitsProcessor,
        ForcedEOSTokenLogitsProcessor,
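With both the lazy `_import_structure` and the `TYPE_CHECKING` branch updated, the new classes resolve from the package root. A quick import smoke test (a sketch; these are exactly the names registered above):

```python
# all of these should now resolve from the top-level package
from transformers import (
    ConstrainedBeamSearchScorer,
    Constraint,
    ConstraintListState,
    PhrasalConstraint,
)
```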
from abc import ABC, abstractmethod
from typing import List, Optional, Union
import torch
class Constraint(ABC):
r"""Abstract base class for all constraints that can be applied during generation.
It must define how the constraint can be satisfied.
All classes that inherit Constraint must follow the requirement that
```py
completed = False
while not completed:
_, completed = constraint.update(constraint.advance())
```
will always terminate (halt).
"""
def __init__(self):
# test for the above condition
self.test()
def test(self):
"""
Tests whether this constraint has been properly defined.
"""
counter = 0
completed = False
while not completed:
if counter == 1:
self.reset()
advance = self.advance()
if not self.does_advance(advance):
raise Exception(
"Custom Constraint is not defined correctly. self.does_advance(self.advance()) must be true."
)
stepped, completed, reset = self.update(advance)
counter += 1
if counter > 10000:
raise Exception("update() does not fulfill the constraint.")
if self.remaining() != 0:
raise Exception("Custom Constraint is not defined correctly.")
@abstractmethod
def advance(self):
"""
When called, returns the token that would take this constraint one step closer to being fulfilled.
Return:
            token_ids (`torch.tensor`): Must be a tensor of token ids (indexable), not a plain integer.
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
@abstractmethod
def does_advance(self, token_id: int):
"""
Reads in a token and returns whether it creates progress.
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
@abstractmethod
def update(self, token_id: int):
"""
        Reads in a token and returns booleans that indicate the progress made by it. Unlike
        `does_advance(self, token_id: int)`, this function updates the state of this object.

        This isn't to test whether a certain token will advance the progress; it's to update the state as if the token
        has been generated. This becomes important when `token_id` != the desired token (refer to the else statement in
        PhrasalConstraint).
Args:
token_id(`int`):
The id of a newly generated token in the beam search.
Return:
stepped(`bool`):
                Whether this constraint has become one step closer to being fulfilled.
completed(`bool`):
Whether this constraint has been completely fulfilled by this token being generated.
reset (`bool`):
Whether this constraint has reset its progress by this token being generated.
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
@abstractmethod
def reset(self):
"""
        Resets the state of this constraint to its initialization. We would call this in cases where the fulfillment
        of a constraint is interrupted by an unwanted token.
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
@abstractmethod
def remaining(self):
"""
Returns the number of remaining steps of `advance()` in order to complete this constraint.
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
@abstractmethod
def copy(self, stateful=False):
"""
Creates a new instance of this constraint.
Args:
            stateful(`bool`): Whether to copy not only the constraint, but also its current state, to the new
                instance.
        Return:
            constraint(`Constraint`): A copy of this constraint.
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
class PhrasalConstraint(Constraint):
r"""
[`Constraint`] enforcing that an ordered sequence of tokens is included in the output.
Args:
        token_ids (`List[int]`):
            The ids of the tokens that must be generated, in order, by the output.
"""
def __init__(self, token_ids: Union[List[int], torch.LongTensor]):
super(Constraint, self).__init__()
is_int_list = isinstance(token_ids, List) and isinstance(token_ids[0], int)
is_tensor = isinstance(token_ids, torch.Tensor)
is_int_tensor = (
is_tensor and token_ids.dtype in [torch.int16, torch.int32, torch.int64] and len(token_ids.size()) == 1
)
        has_negative = torch.any(token_ids < 0) if is_tensor else len([t for t in token_ids if t < 0]) > 0
        if isinstance(token_ids, int) or not (is_int_list or is_int_tensor) or has_negative:
            raise ValueError(f"`token_ids` has to be a single list or tensor of non-negative integers but is {token_ids}")
if not is_tensor:
token_ids = torch.tensor(token_ids)
self.token_ids = token_ids
self.seqlen = self.token_ids.size(0)
self.fulfilled_idx = -1 # the index of the currently fulfilled step
self.completed = False
def advance(self):
return self.token_ids[self.fulfilled_idx + 1]
def does_advance(self, token_id: int):
if self.completed:
return False
# move to cpu to guarantee no device issues.
return token_id.cpu() == self.token_ids[self.fulfilled_idx + 1].cpu()
def update(self, token_id: int):
stepped = False
completed = False
reset = False
if self.does_advance(token_id):
self.fulfilled_idx += 1
stepped = True
if self.fulfilled_idx == (self.seqlen - 1):
completed = True
self.completed = completed
else:
# failed to make progress.
reset = True
self.reset()
return stepped, completed, reset
def reset(self):
self.completed = False
        self.fulfilled_idx = -1  # back to the initialization state: no tokens fulfilled yet
def remaining(self):
return self.seqlen - (self.fulfilled_idx + 1)
def copy(self, stateful=False):
new_constraint = PhrasalConstraint(self.token_ids)
if stateful:
            new_constraint.seqlen = self.seqlen
new_constraint.fulfilled_idx = self.fulfilled_idx
new_constraint.completed = self.completed
return new_constraint
class ConstraintListState:
r"""
    A class for beam scorers to track their progress through a list of constraints.
Args:
constraints (`List[Constraint]`):
A list of [`Constraint`] objects that must be fulfilled by the beam scorer.
"""
def __init__(self, constraints: List[Constraint]):
self.constraints = constraints
# max # of steps required to fulfill a given constraint
self.max_seqlen = max([c.seqlen for c in constraints if isinstance(c, PhrasalConstraint)])
self.n_constraints = len(constraints)
self.completed = False
self.init_state()
def init_state(self):
self.complete_constraints = []
self.inprogress_constraint = None
self.pending_constraints = [constraint.copy(stateful=False) for constraint in self.constraints]
def get_bank(self):
add = 0
if self.inprogress_constraint:
# extra points for having a constraint mid-fulfilled
add += self.max_seqlen - self.inprogress_constraint.remaining()
return (len(self.complete_constraints) * self.max_seqlen) + add
def advance(self):
"""The list of tokens to generate such that we can make progress.
By "list" we don't mean the list of token that will fully fulfill a constraint.
Given constraints `c_i = {t_ij | j == # of tokens}`, If we're not in the middle of progressing through a
specific constraint `c_i`, we return:
`[t_k1 for k in indices of unfulfilled constraints]`
If we are in the middle of a constraint, then we return:
`[t_ij]`, where `i` is the index of the inprogress constraint, `j` is the next step for the constraint.
Though we don't care which constraint is fulfilled first, if we are in the progress of fulfilling a constraint,
that's the only one we'll return.
"""
if self.inprogress_constraint is None:
token_list = []
for constraint in self.pending_constraints: # "pending" == "unfulfilled yet"
advance = constraint.advance()
token_list.append(advance)
else:
token_list = [self.inprogress_constraint.advance()]
if len(token_list) == 0:
return None
else:
return torch.stack(token_list)
def reset(self, token_ids: Optional[torch.LongTensor]):
"""
token_ids: the tokens generated thus far to reset the state of the progress through constraints.
"""
self.init_state()
if token_ids is not None and token_ids.size(0) > 0:
for token in token_ids:
# completes or steps **one** constraint
complete, stepped = self.add(token)
# the entire list of constraints are fulfilled
if self.completed:
break
return self
def add(self, token_id: Union[int, torch.LongTensor]):
complete, stepped = False, False
if self.completed:
complete = True
stepped = False
return complete, stepped
if self.inprogress_constraint is not None:
            # In the middle of fulfilling a constraint. If the `token_id` *does* make incremental progress on the
            # current constraint, simply update the state.
stepped, complete, reset = self.inprogress_constraint.update(token_id)
if reset:
# 1. If the next token breaks the progress, then we must restart.
# e.g. constraint = "I love pies" and sequence so far is "I love" but `token_id` == "books".
# But that doesn't mean we self.init_state(), since we only reset the state for this particular
# constraint, not the full list of constraints.
self.pending_constraints.append(self.inprogress_constraint.copy(stateful=False))
self.inprogress_constraint = None
if complete:
# 2. If the next token completes the constraint, move it to completed list, set
# inprogress to None. If there are no pending constraints either, then this full list of constraints
# is complete.
self.complete_constraints.append(self.inprogress_constraint)
self.inprogress_constraint = None
if len(self.pending_constraints) == 0:
# we're done!
self.completed = True
else:
            # Not in the middle of fulfilling a constraint. So does this `token_id` help us step towards any of the
            # constraints in our list?
for cidx, pending_constraint in enumerate(self.pending_constraints):
if pending_constraint.does_advance(token_id):
stepped, complete, reset = pending_constraint.update(token_id)
if not stepped:
raise Exception(
"constraint.update(token_id) is not yielding incremental progress, "
"even though constraint.does_advance(token_id) is true."
)
if complete:
self.complete_constraints.append(pending_constraint)
self.inprogress_constraint = None
if not complete and stepped:
self.inprogress_constraint = pending_constraint
if complete or stepped:
# If we made any progress at all, then it's at least not a "pending constraint".
self.pending_constraints = (
self.pending_constraints[:cidx] + self.pending_constraints[cidx + 1 :]
)
if len(self.pending_constraints) == 0 and self.inprogress_constraint is None:
# If there's no longer any pending after this and no inprogress either, then we must be
# complete.
self.completed = True
break # prevent accidentally stepping through multiple constraints with just one token.
return complete, stepped
def copy(self, stateful=True):
        new_state = ConstraintListState(self.constraints)  # we never mutate the objects in `self.constraints`,
        # so they remain in their initialization state throughout this process.
if stateful:
new_state.complete_constraints = [
constraint.copy(stateful=True) for constraint in self.complete_constraints
]
if self.inprogress_constraint is not None:
new_state.inprogress_constraint = self.inprogress_constraint.copy(stateful=True)
new_state.pending_constraints = [constraint.copy() for constraint in self.pending_constraints]
return new_state
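To make the contract above concrete, here is a small step-through of the two classes (a sketch, not part of the diff; it assumes the module is importable as `transformers.generation_beam_constraints`, matching the import changes above):

```python
import torch

from transformers.generation_beam_constraints import ConstraintListState, PhrasalConstraint

# a phrase of three token ids that must appear, in order, in the output
constraint = PhrasalConstraint(torch.tensor([5, 6, 7]))

print(constraint.advance())  # tensor(5) -> the next token that makes progress
stepped, completed, reset = constraint.update(torch.tensor(5))
print(stepped, completed, reset)  # True False False
print(constraint.remaining())  # 2 -> two more update() steps to go

# a non-matching token resets the progress of this constraint only
_, _, reset = constraint.update(torch.tensor(99))
print(reset)  # True

# ConstraintListState tracks a whole list of constraints for one beam
state = ConstraintListState([PhrasalConstraint(torch.tensor([5, 6, 7]))])
state.add(torch.tensor(5))
print(state.get_bank())  # 1 -> partial credit for the in-progress constraint
```

`get_bank()` is what the beam scorer uses for its "bank" selection: a completed constraint is worth `max_seqlen` points and an in-progress one earns `max_seqlen - remaining()`, so beams can be grouped by how much constraint progress they carry.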
@@ -80,6 +80,27 @@ class TextDatasetForNextSentencePrediction(metaclass=DummyObject):
        requires_backends(self, ["torch"])

+class Constraint(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+class ConstraintListState(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+class PhrasalConstraint(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
class BeamScorer(metaclass=DummyObject):
    _backends = ["torch"]
@@ -94,6 +115,13 @@ class BeamSearchScorer(metaclass=DummyObject):
        requires_backends(self, ["torch"])

+class ConstrainedBeamSearchScorer(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
class ForcedBOSTokenLogitsProcessor(metaclass=DummyObject):
    _backends = ["torch"]
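These dummies keep `from transformers import PhrasalConstraint` working when torch is not installed, deferring the failure to first use. A rough sketch of the pattern (the real `DummyObject` and `requires_backends` live in the library's utils; this simplified version is an assumption, not the actual implementation):

```python
# simplified sketch of the dummy-object pattern, not the real transformers code
def requires_backends(obj, backends):
    # the real helper checks whether the backends are importable before raising
    name = obj.__name__ if isinstance(obj, type) else obj.__class__.__name__
    raise ImportError(f"{name} requires the {backends} backend(s) to be installed.")

class DummyObject(type):
    # attribute access on the class itself also surfaces the backend error
    def __getattr__(cls, key):
        requires_backends(cls, cls._backends)

class PhrasalConstraint(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
```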
@@ -25,7 +25,8 @@ from .test_modeling_common import floats_tensor, ids_tensor

if is_torch_available():
    import torch

-    from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer
+    from transformers.generation_beam_constraints import PhrasalConstraint
+    from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer, ConstrainedBeamSearchScorer

class BeamSearchTester:
@@ -232,6 +233,270 @@ class BeamSearchTester:
        self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])
class ConstrainedBeamSearchTester:
def __init__(
self,
parent,
constraints=None,
batch_size=3,
sequence_length=10,
vocab_size=99,
pad_token_id=0,
max_length=20,
num_beams=4,
length_penalty=2.0,
do_early_stopping=True,
num_beam_hyps_to_keep=2,
):
self.parent = parent
self.batch_size = batch_size
self.sequence_length = sequence_length
self.vocab_size = vocab_size
self.pad_token_id = pad_token_id
self.max_length = max_length
self.num_beams = num_beams
self.length_penalty = length_penalty
self.do_early_stopping = do_early_stopping
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
if constraints is None:
force_tokens = torch.randint(10, 50, (1, 2)).type(torch.LongTensor)[0]
constraints = [
PhrasalConstraint(force_tokens),
]
self.constraints = constraints
        # cannot be randomly generated
self.eos_token_id = vocab_size + 1
def prepare_constrained_beam_scorer(self, **kwargs):
return ConstrainedBeamSearchScorer(
constraints=kwargs.get("constraints", self.constraints),
batch_size=kwargs.get("batch_size", self.batch_size),
num_beams=kwargs.get("num_beams", self.num_beams),
device=torch_device,
length_penalty=kwargs.get("length_penalty", self.length_penalty),
do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping),
num_beam_hyps_to_keep=kwargs.get("num_beam_hyps_to_keep", self.num_beam_hyps_to_keep),
)
def prepare_inputs(self):
input_ids = ids_tensor((self.batch_size * self.num_beams, self.sequence_length), self.vocab_size)
next_tokens = ids_tensor((self.batch_size, 2 * self.num_beams), self.vocab_size).to(torch_device)
next_indices = ids_tensor((self.batch_size, 2 * self.num_beams), self.num_beams).to(torch_device)
next_scores, _ = (-floats_tensor((self.batch_size, 2 * self.num_beams)).to(torch_device)).sort(descending=True)
scores_for_all_vocab, _ = (
-floats_tensor((self.batch_size * self.num_beams, self.vocab_size)).to(torch_device)
).sort(descending=True)
return (input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab)
def check_beam_hypotheses(self, input_ids, *args):
# check that correct number of beam hypotheses is set in beam scorer
constrained_beam_scorer = self.prepare_constrained_beam_scorer(do_early_stopping=True)
beam_hyp = constrained_beam_scorer._beam_hyps[0]
self.parent.assertEqual(len(constrained_beam_scorer._beam_hyps), self.batch_size)
# check correct type
self.parent.assertTrue(isinstance(beam_hyp, BeamHypotheses))
# check that num_beams is correctly set
self.parent.assertEqual(beam_hyp.num_beams, self.num_beams)
# check for early stopping deactivated
for beam_idx in range(self.num_beams):
beam_hyp.add(input_ids[beam_idx], -10.0)
# if early stopping True -> score does not matter
self.parent.assertTrue(beam_hyp.is_done(-10.0, 5))
# re-init
constrained_beam_scorer = self.prepare_constrained_beam_scorer(do_early_stopping=False)
beam_hyp = constrained_beam_scorer._beam_hyps[0]
# add `num_beams + 1` beams to change `worst_score`
for beam_idx in range(self.num_beams + 1):
beam_hyp.add(input_ids[beam_idx], -10.0 + float(beam_idx))
# -10.0 is removed => -9.0 is worst score
self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length ** beam_hyp.length_penalty))
# -5.0 is better than worst score => should not be finished
self.parent.assertFalse(beam_hyp.is_done(-5.0, self.sequence_length))
# -20.0 is worse than worst score => should be finished
self.parent.assertTrue(beam_hyp.is_done(-20.0, self.sequence_length))
def check_constrained_beam_scorer_update(
self, input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab
):
# check too many eos tokens
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
fulfilling_sequence = torch.stack([constraint.token_ids for constraint in self.constraints]).flatten()
fulfill_len = fulfilling_sequence.size(0)
input_ids[:, :fulfill_len] = fulfilling_sequence
tokens = next_tokens.clone()
tokens[0, :] = self.eos_token_id
with self.parent.assertRaises(ValueError):
constrained_beam_scorer.process(
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
)
# check all batches are done
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
tokens = next_tokens.clone()
tokens[:, : self.num_beams] = self.eos_token_id
constrained_beam_scorer.process(
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
)
# beam scorer should be done
self.parent.assertTrue(constrained_beam_scorer.is_done)
        # check that the eos candidate in each batch is skipped when selecting the next beams
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
tokens = next_tokens.clone()
tokens[:, 1] = self.eos_token_id
beam_outputs = constrained_beam_scorer.process(
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
output_indices = beam_outputs["next_beam_indices"]
def cut_expected_tensor(tensor):
return torch.cat([tensor[:, :1], tensor[:, 2 : self.num_beams + 1]], dim=1).flatten()
        # check all outputs
# cut out id of eos token and take best `num_beams` outputs
expected_output_tokens = cut_expected_tensor(tokens)
expected_output_scores = cut_expected_tensor(next_scores)
# add num_beams * batch_idx
expected_output_indices = (
cut_expected_tensor(next_indices)
+ (torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams) * self.num_beams
)
self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))
# make sure ids of eos token are correctly saved in beam_hyps of beam scorer
for batch_idx in range(self.batch_size):
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
self.parent.assertListEqual(
input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
)
def check_constrained_beam_scorer_finalize(
self, input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab
):
# max_length should be only one more than current input_ids to check that eos is correctly appended
max_length = self.sequence_length + 1
# for testing finalize, we do want to have fulfilled constraints
fulfilling_sequence = torch.stack([constraint.token_ids for constraint in self.constraints]).flatten()
fulfill_len = fulfilling_sequence.size(0)
input_ids[:, :fulfill_len] = fulfilling_sequence
constrained_beam_scorer = self.prepare_constrained_beam_scorer(
num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False
)
constraints = constrained_beam_scorer.constraints
# update beams and append to input_ids
tokens = next_tokens.clone()
# first batch, first output has to finish with eos token id since scores are correctly sorted
tokens[0, 0] = self.eos_token_id
# make sure corresponding score is as good as possible to surely be picked first
next_scores[0, 0] = 0.0
beam_outputs = constrained_beam_scorer.process(
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
output_indices = beam_outputs["next_beam_indices"]
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
# finalize
sequence_output = constrained_beam_scorer.finalize(
input_ids,
output_scores,
output_tokens,
output_indices,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
)
sequences = sequence_output["sequences"]
sequence_scores = sequence_output["sequence_scores"]
# since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length`
self.parent.assertListEqual(list(sequences.shape), [self.batch_size, max_length])
self.parent.assertListEqual(list(sequence_scores.shape), [self.batch_size])
# check sequence_scores
self.parent.assertFalse((sequence_scores > 0).any().item())
# first batch has to finish with eos_token
self.parent.assertEqual(sequences[0, -1].item(), self.eos_token_id)
# other batches cannot finish with eos token
self.parent.assertNotEqual(sequences[1, -1].item(), self.eos_token_id)
self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id)
# test that the constraint is indeed fulfilled
for output in sequences:
for constraint in constraints:
forced_token_ids = constraint.token_ids
self.parent.assertEqual(self._check_sequence_inside_sequence(output, forced_token_ids), True)
        # now test that if `num_beam_hyps_to_keep` == `num_beams` => all beams are returned
# constrained_beam_scorer.num_beam_hyps_to_keep = self.num_beams
constrained_beam_scorer = self.prepare_constrained_beam_scorer(
num_beam_hyps_to_keep=self.num_beams, length_penalty=1.0, do_early_stopping=False
)
sequence_output = constrained_beam_scorer.finalize(
input_ids,
output_scores,
output_tokens,
output_indices,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
)
sequences = sequence_output["sequences"]
sequence_scores = sequence_output["sequence_scores"]
self.parent.assertListEqual(list(sequences.shape), [self.num_beams * self.batch_size, max_length])
self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
# set to same device. we don't care what device.
tensor_1, tensor_2 = tensor_1.cpu(), tensor_2.cpu()
in_order = tensor_1.size(0) <= tensor_2.size(0)
longer = tensor_2 if in_order else tensor_1
shorter = tensor_1 if in_order else tensor_2
flag = False
chunk_size = shorter.size(0)
for chunk_idx in range(longer.size(0) - chunk_size + 1):
subseq = longer[chunk_idx : chunk_idx + chunk_size]
if torch.equal(subseq, shorter):
flag = True
break
return flag
@require_torch
class BeamSearchTest(unittest.TestCase):
    def setUp(self):
@@ -248,3 +513,21 @@ class BeamSearchTest(unittest.TestCase):
    def test_beam_scorer_finalize(self):
        inputs = self.beam_search_tester.prepare_inputs()
        self.beam_search_tester.check_beam_scores_finalize(*inputs)
@require_torch
class ConstrainedBeamSearchTest(unittest.TestCase):
def setUp(self):
self.constrained_beam_search_tester = ConstrainedBeamSearchTester(self)
def test_constrained_beam_hypotheses(self):
inputs = self.constrained_beam_search_tester.prepare_inputs()
self.constrained_beam_search_tester.check_beam_hypotheses(*inputs)
def test_constrained_beam_scorer_update(self):
inputs = self.constrained_beam_search_tester.prepare_inputs()
self.constrained_beam_search_tester.check_constrained_beam_scorer_update(*inputs)
def test_constrained_beam_scorer_finalize(self):
inputs = self.constrained_beam_search_tester.prepare_inputs()
self.constrained_beam_search_tester.check_constrained_beam_scorer_finalize(*inputs)
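For reference, the tester above drives the scorer directly with the shapes below; a condensed, runnable sketch (the constructor and `process()` arguments mirror `prepare_constrained_beam_scorer` and `prepare_inputs`; the random values are arbitrary):

```python
import torch

from transformers.generation_beam_constraints import PhrasalConstraint
from transformers.generation_beam_search import ConstrainedBeamSearchScorer

batch_size, num_beams, seq_len, vocab_size = 3, 4, 10, 99

scorer = ConstrainedBeamSearchScorer(
    batch_size=batch_size,
    num_beams=num_beams,
    device="cpu",
    constraints=[PhrasalConstraint(torch.tensor([11, 12]))],
)

input_ids = torch.randint(0, vocab_size, (batch_size * num_beams, seq_len))
next_tokens = torch.randint(0, vocab_size, (batch_size, 2 * num_beams))
next_indices = torch.randint(0, num_beams, (batch_size, 2 * num_beams))
# candidate scores arrive sorted in descending order, as in prepare_inputs
next_scores, _ = (-torch.rand(batch_size, 2 * num_beams)).sort(descending=True)
scores_for_all_vocab = -torch.rand(batch_size * num_beams, vocab_size)

out = scorer.process(
    input_ids, next_scores, next_tokens, next_indices, scores_for_all_vocab,
    eos_token_id=vocab_size + 1,  # outside the sampled range, as in the tester
)
print(out["next_beam_scores"].shape)  # torch.Size([12]) == batch_size * num_beams
```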
@@ -27,6 +27,8 @@ if is_torch_available():
    import torch

    from transformers import (
+        AutoModelForSeq2SeqLM,
+        AutoTokenizer,
        BartForConditionalGeneration,
        BartTokenizer,
        GPT2LMHeadModel,
@@ -37,7 +39,8 @@ if is_torch_available():
        VisionEncoderDecoderModel,
        top_k_top_p_filtering,
    )
-    from transformers.generation_beam_search import BeamSearchScorer
+    from transformers.generation_beam_constraints import PhrasalConstraint
+    from transformers.generation_beam_search import BeamSearchScorer, ConstrainedBeamSearchScorer
    from transformers.generation_logits_process import (
        ForcedBOSTokenLogitsProcessor,
        ForcedEOSTokenLogitsProcessor,
@@ -190,6 +193,25 @@ class GenerationTesterMixin:
        )
        return beam_kwargs, beam_scorer
@staticmethod
def _get_constrained_beam_scorer_and_kwargs(batch_size, max_length, constraints, num_return_sequences=1):
beam_kwargs = {
"early_stopping": False,
"length_penalty": 2.0,
"num_beams": num_return_sequences * 4,
"num_return_sequences": num_return_sequences,
}
beam_scorer = ConstrainedBeamSearchScorer(
batch_size=batch_size,
constraints=constraints,
num_beams=beam_kwargs["num_beams"],
device=torch_device,
length_penalty=beam_kwargs["length_penalty"],
do_early_stopping=beam_kwargs["early_stopping"],
num_beam_hyps_to_keep=num_return_sequences,
)
return beam_kwargs, beam_scorer
    @staticmethod
    def _get_encoder_outputs(
        model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
@@ -526,6 +548,69 @@ class GenerationTesterMixin:
        )
        return output_generate, output_group_beam_search
def _constrained_beam_search_generate(
self,
model,
input_ids,
attention_mask,
max_length,
constrained_beam_scorer,
constraints,
beam_kwargs,
logits_processor,
logits_process_kwargs,
output_scores=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
):
output_generate = model.generate(
input_ids,
attention_mask=attention_mask,
do_sample=False,
max_length=max_length,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
constraints=constraints,
**beam_kwargs,
**logits_process_kwargs,
)
        # constrained_beam_search does not automatically interleave `batch_size` dim for `num_beams`
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
model,
input_ids,
attention_mask,
num_interleave=constrained_beam_scorer.num_beams,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
kwargs["encoder_outputs"] = encoder_outputs
input_ids_clone = input_ids_clone.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
else:
attention_mask_clone = attention_mask.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
input_ids_clone = input_ids.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
with torch.no_grad():
            output_constrained_beam_search = model.constrained_beam_search(
input_ids_clone,
constrained_beam_scorer,
max_length=max_length,
attention_mask=attention_mask_clone,
logits_processor=logits_processor,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
**kwargs,
)
        return output_generate, output_constrained_beam_search
    def test_greedy_generate(self):
        # check `generate()` and `greedy_search()` are equal
        for model_class in self.all_generative_model_classes:
@@ -719,6 +804,7 @@ class GenerationTesterMixin:
                logits_process_kwargs=logits_process_kwargs,
                logits_processor=logits_processor,
            )
+
            self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())

            # check `generate()` and `beam_search()` are equal for `num_return_sequences`
@@ -1085,6 +1171,164 @@ class GenerationTesterMixin:
                output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
            )
def test_constrained_beam_search_generate(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
            # It is important to set the eos_token_id to None to ensure that no sequences
            # shorter than `max_length` can be generated, which could lead to flaky circle ci
            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
config.eos_token_id = None
config.forced_eos_token_id = None
model = model_class(config).to(torch_device).eval()
max_length = 20
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1],
config.eos_token_id,
config.forced_bos_token_id,
config.forced_eos_token_id,
max_length,
)
# check `generate()` and `constrained_beam_search()` are equal
# Sample constraints
if not input_ids.dtype == torch.float32:
min_id = torch.min(input_ids) + 3
max_id = torch.max(input_ids)
else:
# otherwise this throws an error for Speech2TextModel since its inputs are floating points
min_id = 3
max_id = 100
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
constraints = [
PhrasalConstraint(force_tokens),
]
beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
input_ids.shape[0], max_length, constraints, num_return_sequences=1
)
output_generate, output_beam_search = self._constrained_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
max_length=max_length,
constrained_beam_scorer=beam_scorer,
constraints=constraints,
beam_kwargs=beam_kwargs,
logits_processor=logits_processor,
logits_process_kwargs=logits_process_kwargs,
)
self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())
for generation_output in output_generate:
self._check_sequence_inside_sequence(force_tokens, generation_output)
# check `generate()` and `constrained_beam_search()` are equal for `num_return_sequences`
# Sample constraints
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
constraints = [
PhrasalConstraint(force_tokens),
]
num_return_sequences = 2
max_length = 20
beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
input_ids.shape[0], max_length, constraints, num_return_sequences=num_return_sequences
)
output_generate, output_beam_search = self._constrained_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
max_length=max_length,
constrained_beam_scorer=beam_scorer,
constraints=constraints,
beam_kwargs=beam_kwargs,
logits_processor=logits_processor,
logits_process_kwargs=logits_process_kwargs,
)
self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())
for generation_output in output_generate:
self._check_sequence_inside_sequence(force_tokens, generation_output)
def test_constrained_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
# disable cache
config.use_cache = False
            # It is important to set the eos_token_id to None to ensure that no sequences
            # shorter than `max_length` can be generated, which could lead to flaky circle ci
            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
config.eos_token_id = None
config.forced_eos_token_id = None
model = model_class(config).to(torch_device).eval()
if model.config.is_encoder_decoder:
max_length = 20
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1],
config.eos_token_id,
config.forced_bos_token_id,
config.forced_eos_token_id,
max_length,
)
# Sample constraints
if not input_ids.dtype == torch.float32:
min_id = torch.min(input_ids) + 3
max_id = torch.max(input_ids)
else:
# otherwise this throws an error for Speech2TextModel since its inputs are floating points
min_id = 3
max_id = 100
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
constraints = [
PhrasalConstraint(force_tokens),
]
beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
input_ids.shape[0], max_length, constraints, num_return_sequences=1
)
output_generate, output_beam_search = self._constrained_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
max_length=max_length,
constrained_beam_scorer=beam_scorer,
constraints=constraints,
beam_kwargs=beam_kwargs,
logits_processor=logits_processor,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self.assertListEqual(output_generate.sequences.tolist(), output_beam_search.sequences.tolist())
self.assertTrue(
torch.allclose(output_generate["sequences_scores"], output_beam_search["sequences_scores"], atol=1e-3)
)
self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],))
self.assertTrue((output_generate["sequences_scores"] < 0).all().item())
for output in (output_beam_search, output_generate):
self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams)
    def test_generate_with_head_masking(self):
        """Test designed for encoder-decoder models to ensure the attention head masking is used."""
        attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
@@ -1254,6 +1498,24 @@ class GenerationTesterMixin:
            [encoder_expected_shape] * len(hidden_states),
        )
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
# set to same device. we don't care what device.
tensor_1, tensor_2 = tensor_1.cpu(), tensor_2.cpu()
in_order = tensor_1.size(0) <= tensor_2.size(0)
longer = tensor_2 if in_order else tensor_1
shorter = tensor_1 if in_order else tensor_2
flag = False
chunk_size = shorter.size(0)
for chunk_idx in range(longer.size(0) - chunk_size + 1):
subseq = longer[chunk_idx : chunk_idx + chunk_size]
if torch.equal(subseq, shorter):
flag = True
break
self.assertTrue(flag)
@require_torch
class UtilsFunctionsTest(unittest.TestCase):
@@ -2047,3 +2309,83 @@ class GenerationIntegrationTests(unittest.TestCase):
        transition_scores_sum = transition_scores.sum(-1)
        self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
@slow
def test_constrained_beam_search(self):
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
force_tokens = tokenizer.encode(" scared", return_tensors="pt").to(torch_device)[0]
force_tokens_2 = tokenizer.encode(" big weapons", return_tensors="pt").to(torch_device)[0]
constraints = [
PhrasalConstraint(force_tokens),
PhrasalConstraint(force_tokens_2),
]
starting_text = ["The soldiers were not prepared and"]
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device)
outputs = model.generate(
input_ids,
constraints=constraints,
num_beams=10,
num_return_sequences=1,
no_repeat_ngram_size=1,
max_length=30,
remove_invalid_values=True,
)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"The soldiers were not prepared and didn't know how big the big weapons would be, so they scared them off. They had no idea what to do",
],
)
@slow
def test_constrained_beam_search_example_integration(self):
tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
encoder_input_str = "translate English to German: How old are you?"
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
        # let's run beam search using 5 beams
num_beams = 5
# define decoder start token ids
input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
input_ids = input_ids * model.config.decoder_start_token_id
# add encoder_outputs to model keyword arguments
model_kwargs = {
"encoder_outputs": model.get_encoder()(
encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
)
}
constraint_str = "sind"
constraint_token_ids = tokenizer.encode(constraint_str)[:-1] # remove eos token
constraints = [PhrasalConstraint(token_ids=constraint_token_ids)]
# instantiate beam scorer
beam_scorer = ConstrainedBeamSearchScorer(
batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
)
# instantiate logits processors
logits_processor = LogitsProcessorList(
[
MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
]
)
outputs = model.constrained_beam_search(
input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(outputs, ["Wie alter sind Sie?"])