Unverified Commit 5c6f57ee authored by Chan Woo Kim's avatar Chan Woo Kim Committed by GitHub
Browse files

Constrained Beam Search [*With* Disjunctive Decoding] (#15761)



* added classes to get started with constrained beam search

* in progress, think i can directly force tokens now but not yet with the round robin

* think now i have total control, now need to code the bank selection

* technically works as desired, need to optimize and fix design choices leading to undersirable outputs

* complete PR #1 without disjunctive decoding

* removed incorrect tests

* Delete k.txt

* Delete test.py

* Delete test.sh

* revert changes to test scripts

* genutils

* full implementation with testing, no disjunctive yet

* shifted docs

* passing all tests realistically ran locally

* removing accidentally included print statements

* fixed source of error in initial PR test

* fixing the get_device() vs device trap

* fixed documentation docstrings about constrained_beam_search

* fixed tests having failing for Speech2TextModel's floating point inputs

* fix cuda long tensor

* added examples and testing for them and founx & fixed a bug in beam_search and constrained_beam_search

* deleted accidentally added test halting code with assert False

* code reformat

* Update tests/test_generation_utils.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py

* fixing based on comments on PR

* took out the testing code that should but work fails without the beam search moditification ; style changes

* fixing comments issues

* docstrings for ConstraintListState

* typo in PhrsalConstraint docstring

* docstrings improvements

* finished adding what is sort of an opinionated implementation of disjunctive generation, but it revealed errors in inner beam search logic during testing.

* fixed bug found in constrained beam search that used beam_idx that were not global across all the batches

* disjunctive constraint working 100% correctly

* passing all tests

* Accidentally included mlruns

* Update src/transformers/generation_beam_constraints.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/generation_beam_constraints.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* complete overhaul of type complexities and other nits

* strict type checks in generate()

* fixing second round of feedback by narsil

* fixed failing generation test because of type check overhaul

* generation test fail fix

* fixing test fails
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 040c11f6
...@@ -229,6 +229,8 @@ A [`Constraint`] can be used to force the generation to include specific tokens ...@@ -229,6 +229,8 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] PhrasalConstraint [[autodoc]] PhrasalConstraint
[[autodoc]] DisjunctiveConstraint
[[autodoc]] ConstraintListState [[autodoc]] ConstraintListState
## BeamSearch ## BeamSearch
......
...@@ -623,6 +623,7 @@ if is_torch_available(): ...@@ -623,6 +623,7 @@ if is_torch_available():
_import_structure["generation_beam_constraints"] = [ _import_structure["generation_beam_constraints"] = [
"Constraint", "Constraint",
"ConstraintListState", "ConstraintListState",
"DisjunctiveConstraint",
"PhrasalConstraint", "PhrasalConstraint",
] ]
_import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"] _import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"]
...@@ -2857,7 +2858,12 @@ if TYPE_CHECKING: ...@@ -2857,7 +2858,12 @@ if TYPE_CHECKING:
TextDataset, TextDataset,
TextDatasetForNextSentencePrediction, TextDatasetForNextSentencePrediction,
) )
from .generation_beam_constraints import Constraint, ConstraintListState, PhrasalConstraint from .generation_beam_constraints import (
Constraint,
ConstraintListState,
DisjunctiveConstraint,
PhrasalConstraint,
)
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .generation_logits_process import ( from .generation_logits_process import (
ForcedBOSTokenLogitsProcessor, ForcedBOSTokenLogitsProcessor,
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional, Union from typing import List, Optional
import torch
class Constraint(ABC): class Constraint(ABC):
...@@ -137,37 +135,38 @@ class PhrasalConstraint(Constraint): ...@@ -137,37 +135,38 @@ class PhrasalConstraint(Constraint):
The id of the token that must be generated by the output. The id of the token that must be generated by the output.
""" """
def __init__(self, token_ids: Union[List[int], torch.LongTensor]): def __init__(self, token_ids: List[int]):
super(Constraint, self).__init__() super(Constraint, self).__init__()
is_int_list = isinstance(token_ids, List) and isinstance(token_ids[0], int) if not isinstance(token_ids, list) or len(token_ids) == 0:
is_tensor = isinstance(token_ids, torch.Tensor) raise ValueError(f"`token_ids` has to be a non-emtpy list, but is {token_ids}.")
is_int_tensor = ( if any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids):
is_tensor and token_ids.dtype in [torch.int16, torch.int32, torch.int64] and len(token_ids.size()) == 1 raise ValueError(f"Each list in `token_ids` has to be a list of positive integers, but is {token_ids}.")
)
not_positive = torch.any(token_ids < 0) if is_tensor else len([t for t in token_ids if t < 0]) > 0
if isinstance(token_ids, int) or not (is_int_list or is_int_tensor) or not_positive:
raise ValueError(f"`token_ids` has to be a single list or tensor of positive integers but is {token_ids}")
if not is_tensor:
token_ids = torch.tensor(token_ids)
self.token_ids = token_ids self.token_ids = token_ids
self.seqlen = self.token_ids.size(0) self.seqlen = len(self.token_ids)
self.fulfilled_idx = -1 # the index of the currently fulfilled step self.fulfilled_idx = -1 # the index of the currently fulfilled step
self.completed = False self.completed = False
def advance(self): def advance(self):
if self.completed:
return None
return self.token_ids[self.fulfilled_idx + 1] return self.token_ids[self.fulfilled_idx + 1]
def does_advance(self, token_id: int): def does_advance(self, token_id: int):
if not isinstance(token_id, int):
raise ValueError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")
if self.completed: if self.completed:
return False return False
# move to cpu to guarantee no device issues.
return token_id.cpu() == self.token_ids[self.fulfilled_idx + 1].cpu() return token_id == self.token_ids[self.fulfilled_idx + 1]
def update(self, token_id: int): def update(self, token_id: int):
if not isinstance(token_id, int):
raise ValueError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")
stepped = False stepped = False
completed = False completed = False
reset = False reset = False
...@@ -202,6 +201,151 @@ class PhrasalConstraint(Constraint): ...@@ -202,6 +201,151 @@ class PhrasalConstraint(Constraint):
return new_constraint return new_constraint
class DisjunctiveTrie:
def __init__(self, nested_token_ids: List[List[int]], no_subsets=True):
r"""
A helper class that builds a trie with the words represented in `nested_token_ids`.
"""
self.max_height = max([len(one) for one in nested_token_ids])
root = dict()
for token_ids in nested_token_ids:
level = root
for tidx, token_id in enumerate(token_ids):
if token_id not in level:
level[token_id] = dict()
level = level[token_id]
if no_subsets and self.has_subsets(root, nested_token_ids):
raise ValueError(
f"Each list in `nested_token_ids` can't be a complete subset of another list, but is {nested_token_ids}."
)
self.trie = root
def next_tokens(self, current_seq):
"""
The next possible tokens that will progress the trie, given the current sequence of tokens in `current_seq`.
"""
start = self.trie
for current_token in current_seq:
start = start[current_token]
next_tokens = list(start.keys())
return next_tokens
def reached_leaf(self, current_seq):
next_tokens = self.next_tokens(current_seq)
return len(next_tokens) == 0
def count_leaves(self, root):
next_nodes = list(root.values())
if len(next_nodes) == 0:
return 1
else:
return sum([self.count_leaves(nn) for nn in next_nodes])
def has_subsets(self, trie, nested_token_ids):
"""
Returns whether # of leaves == # of words. Otherwise some word is a subset of another.
"""
leaf_count = self.count_leaves(trie)
return len(nested_token_ids) != leaf_count
class DisjunctiveConstraint(Constraint):
r"""
A special [`Constraint`] that is fulfilled by fulfilling just one of several constraints.
Args:
nested_token_ids (`List[List[int]]`): a list of words, where each word is a list of ids. This constraint
is fulfilled by generating just one from the list of words.
"""
def __init__(self, nested_token_ids: List[List[int]]):
super(Constraint, self).__init__()
if not isinstance(nested_token_ids, list) or len(nested_token_ids) == 0:
raise ValueError(f"`nested_token_ids` has to be a non-emtpy list, but is {nested_token_ids}.")
if any(not isinstance(token_ids, list) for token_ids in nested_token_ids):
raise ValueError(f"`nested_token_ids` has to be a list of lists, but is {nested_token_ids}.")
if any(
any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
for token_ids in nested_token_ids
):
raise ValueError(
f"Each list in `nested_token_ids` has to be a list of positive integers, but is {nested_token_ids}."
)
self.trie = DisjunctiveTrie(nested_token_ids)
self.token_ids = nested_token_ids
self.seqlen = self.trie.max_height
self.current_seq = []
self.completed = False
def advance(self):
token_list = self.trie.next_tokens(self.current_seq)
if len(token_list) == 0:
return None
else:
return token_list
def does_advance(self, token_id: int):
if not isinstance(token_id, int):
raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")
next_tokens = self.trie.next_tokens(self.current_seq)
return token_id in next_tokens
def update(self, token_id: int):
if not isinstance(token_id, int):
raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")
stepped = False
completed = False
reset = False
if self.does_advance(token_id):
self.current_seq.append(token_id)
stepped = True
else:
reset = True
self.reset()
completed = self.trie.reached_leaf(self.current_seq)
self.completed = completed
return stepped, completed, reset
def reset(self):
self.completed = False
self.current_seq = []
def remaining(self):
if self.completed:
# since this can be completed without reaching max height
return 0
else:
return self.seqlen - len(self.current_seq)
def copy(self, stateful=False):
new_constraint = DisjunctiveConstraint(self.token_ids)
if stateful:
new_constraint.seq_len = self.seqlen
new_constraint.current_seq = self.current_seq
new_constraint.completed = self.completed
return new_constraint
class ConstraintListState: class ConstraintListState:
r""" r"""
A class for beam scorers to track its progress through a list of constraints. A class for beam scorers to track its progress through a list of constraints.
...@@ -215,7 +359,7 @@ class ConstraintListState: ...@@ -215,7 +359,7 @@ class ConstraintListState:
self.constraints = constraints self.constraints = constraints
# max # of steps required to fulfill a given constraint # max # of steps required to fulfill a given constraint
self.max_seqlen = max([c.seqlen for c in constraints if isinstance(c, PhrasalConstraint)]) self.max_seqlen = max([c.seqlen for c in constraints])
self.n_constraints = len(constraints) self.n_constraints = len(constraints)
self.completed = False self.completed = False
...@@ -249,26 +393,33 @@ class ConstraintListState: ...@@ -249,26 +393,33 @@ class ConstraintListState:
Though we don't care which constraint is fulfilled first, if we are in the progress of fulfilling a constraint, Though we don't care which constraint is fulfilled first, if we are in the progress of fulfilling a constraint,
that's the only one we'll return. that's the only one we'll return.
""" """
token_list = []
if self.inprogress_constraint is None: if self.inprogress_constraint is None:
token_list = []
for constraint in self.pending_constraints: # "pending" == "unfulfilled yet" for constraint in self.pending_constraints: # "pending" == "unfulfilled yet"
advance = constraint.advance() advance = constraint.advance()
token_list.append(advance) if isinstance(advance, int):
token_list.append(advance)
elif isinstance(advance, list):
token_list.extend(advance)
else: else:
token_list = [self.inprogress_constraint.advance()] advance = self.inprogress_constraint.advance()
if isinstance(advance, int):
token_list.append(advance)
elif isinstance(advance, list):
token_list.extend(advance)
if len(token_list) == 0: if len(token_list) == 0:
return None return None
else: else:
return torch.stack(token_list) return token_list
def reset(self, token_ids: Optional[torch.LongTensor]): def reset(self, token_ids: Optional[List[int]]):
""" """
token_ids: the tokens generated thus far to reset the state of the progress through constraints. token_ids: the tokens generated thus far to reset the state of the progress through constraints.
""" """
self.init_state() self.init_state()
if token_ids is not None and token_ids.size(0) > 0: if token_ids is not None:
for token in token_ids: for token in token_ids:
# completes or steps **one** constraint # completes or steps **one** constraint
complete, stepped = self.add(token) complete, stepped = self.add(token)
...@@ -277,9 +428,10 @@ class ConstraintListState: ...@@ -277,9 +428,10 @@ class ConstraintListState:
if self.completed: if self.completed:
break break
return self def add(self, token_id: int):
if not isinstance(token_id, int):
raise ValueError(f"`token_id` should be an `int`, but is `{token_id}`.")
def add(self, token_id: Union[int, torch.LongTensor]):
complete, stepped = False, False complete, stepped = False, False
if self.completed: if self.completed:
...@@ -324,8 +476,8 @@ class ConstraintListState: ...@@ -324,8 +476,8 @@ class ConstraintListState:
if not stepped: if not stepped:
raise Exception( raise Exception(
"constraint.update(token_id) is not yielding incremental progress, " "`constraint.update(token_id)` is not yielding incremental progress, "
"even though constraint.does_advance(token_id) is true." "even though `constraint.does_advance(token_id)` is true."
) )
if complete: if complete:
......
...@@ -443,7 +443,7 @@ class ConstrainedBeamSearchScorer(BeamScorer): ...@@ -443,7 +443,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
def check_completes_constraints(self, sequence): def check_completes_constraints(self, sequence):
new_state = self.make_constraint_states(1)[0] new_state = self.make_constraint_states(1)[0]
new_state = new_state.reset(sequence) new_state.reset(sequence)
return new_state.completed return new_state.completed
def process( def process(
...@@ -484,6 +484,7 @@ class ConstrainedBeamSearchScorer(BeamScorer): ...@@ -484,6 +484,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
- **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of - **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of
all all
non-finished beams. non-finished beams.
- **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be - **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be
added added
to the non-finished beam_hypotheses. to the non-finished beam_hypotheses.
...@@ -537,7 +538,7 @@ class ConstrainedBeamSearchScorer(BeamScorer): ...@@ -537,7 +538,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
if is_beam_token_worse_than_top_num_beams: if is_beam_token_worse_than_top_num_beams:
continue continue
completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx]) completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx].cpu().tolist())
if completes_constraint: if completes_constraint:
beam_hyp.add( beam_hyp.add(
input_ids[batch_beam_idx].clone(), input_ids[batch_beam_idx].clone(),
...@@ -628,23 +629,23 @@ class ConstrainedBeamSearchScorer(BeamScorer): ...@@ -628,23 +629,23 @@ class ConstrainedBeamSearchScorer(BeamScorer):
# hypotheses. # hypotheses.
topk_state = topk_contraint_states[seq_idx] topk_state = topk_contraint_states[seq_idx]
topk_state.reset(full_hypotheses[seq_idx]) topk_state.reset(full_hypotheses[seq_idx].cpu().tolist())
advance_state = advance_constraint_states[seq_idx] advance_state = advance_constraint_states[seq_idx]
advance_state.reset(pre_seq) advance_state.reset(pre_seq.cpu().tolist())
if not advance_state.completed: if not advance_state.completed:
advance_tokens = advance_state.advance() advance_tokens = torch.LongTensor(advance_state.advance()).to(device)
for advance_token in advance_tokens.to(device): for advance_token in advance_tokens:
# since adding each `advance_token` leads to a different hypothesis, create new state instance. # since adding each `advance_token` leads to a different hypothesis, create new state instance.
new_state = advance_state.copy(stateful=True) new_state = advance_state.copy(stateful=True)
new_state.add(advance_token) new_state.add(advance_token.cpu().tolist())
advance_seq = torch.cat((pre_seq, advance_token.unsqueeze(0)), -1).cpu().tolist() advance_seq = torch.cat((pre_seq, advance_token.unsqueeze(0)), -1).cpu().tolist()
if advance_seq not in track_new["new_seqs"]: if advance_seq not in track_new["new_seqs"]:
# prevent duplicates, which are basically bound to happen in this process. # prevent duplicates, which are basically bound to happen in this process.
track_new["new_seqs"].append(advance_seq) track_new["new_seqs"].append(advance_seq)
track_new["new_indices"].append(seq_idx) track_new["new_indices"].append(sidx + seq_idx) # idx -> global idx across all the batches
track_new["new_tokens"].append(advance_token) track_new["new_tokens"].append(advance_token)
track_new["new_scores"].append(this_batch_token_scores[seq_idx].take(advance_token)) track_new["new_scores"].append(this_batch_token_scores[seq_idx].take(advance_token))
track_new["new_states"].append(new_state) track_new["new_states"].append(new_state)
...@@ -673,8 +674,9 @@ class ConstrainedBeamSearchScorer(BeamScorer): ...@@ -673,8 +674,9 @@ class ConstrainedBeamSearchScorer(BeamScorer):
advance_state = advance_constraint_states[seq_idx] advance_state = advance_constraint_states[seq_idx]
advance_state.reset(advance_seq)
advance_seq = advance_seq.cpu().tolist() advance_seq = advance_seq.cpu().tolist()
advance_state.reset(advance_seq)
if advance_seq not in track_new["new_seqs"]: if advance_seq not in track_new["new_seqs"]:
# but still don't want to have duplicates # but still don't want to have duplicates
track_new["new_seqs"].append(advance_seq) track_new["new_seqs"].append(advance_seq)
...@@ -745,7 +747,7 @@ class ConstrainedBeamSearchScorer(BeamScorer): ...@@ -745,7 +747,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
final_score = final_beam_scores[batch_beam_idx].item() final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx] final_tokens = input_ids[batch_beam_idx]
completes_constraint = self.check_completes_constraints(final_tokens) completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
if completes_constraint: if completes_constraint:
beam_hyp.add(final_tokens, final_score) beam_hyp.add(final_tokens, final_score)
ids_collect.append(beam_id) ids_collect.append(beam_id)
......
...@@ -24,7 +24,7 @@ import torch.distributed as dist ...@@ -24,7 +24,7 @@ import torch.distributed as dist
from torch import nn from torch import nn
from .file_utils import ModelOutput from .file_utils import ModelOutput
from .generation_beam_constraints import Constraint from .generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .generation_logits_process import ( from .generation_logits_process import (
EncoderNoRepeatNGramLogitsProcessor, EncoderNoRepeatNGramLogitsProcessor,
...@@ -818,6 +818,7 @@ class GenerationMixin: ...@@ -818,6 +818,7 @@ class GenerationMixin:
typical_p: Optional[float] = None, typical_p: Optional[float] = None,
repetition_penalty: Optional[float] = None, repetition_penalty: Optional[float] = None,
bad_words_ids: Optional[Iterable[int]] = None, bad_words_ids: Optional[Iterable[int]] = None,
force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
bos_token_id: Optional[int] = None, bos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None, pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None, eos_token_id: Optional[int] = None,
...@@ -904,6 +905,11 @@ class GenerationMixin: ...@@ -904,6 +905,11 @@ class GenerationMixin:
List of token ids that are not allowed to be generated. In order to get the token ids of the words that List of token ids that are not allowed to be generated. In order to get the token ids of the words that
should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True, should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
add_special_tokens=False).input_ids`. add_special_tokens=False).input_ids`.
force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*):
List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple
list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`,
this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081),
where one can allow different forms of each word.
num_return_sequences(`int`, *optional*, defaults to 1): num_return_sequences(`int`, *optional*, defaults to 1):
The number of independently computed returned sequences for each element in the batch. The number of independently computed returned sequences for each element in the batch.
max_time(`float`, *optional*, defaults to None): max_time(`float`, *optional*, defaults to None):
...@@ -1038,10 +1044,18 @@ class GenerationMixin: ...@@ -1038,10 +1044,18 @@ class GenerationMixin:
>>> bad_words_ids = tokenizer( >>> bad_words_ids = tokenizer(
... ["idiot", "stupid", "shut up"], add_prefix_space=True, add_special_tokens=False ... ["idiot", "stupid", "shut up"], add_prefix_space=True, add_special_tokens=False
>>> ).input_ids >>> ).input_ids
>>> # get tokens of words that we want generated
>>> force_words_ids = tokenizer(["runs", "loves"], add_prefix_space=True, add_special_tokens=False).input_ids
>>> # encode input context >>> # encode input context
>>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
>>> # generate sequences without allowing bad_words to be generated >>> # generate sequences without allowing bad_words to be generated
>>> outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids) >>> outputs = model.generate(
... input_ids=input_ids,
... max_length=20,
... do_sample=True,
... bad_words_ids=bad_words_ids,
... force_words_ids=force_words_ids,
... )
>>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True)) >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
```""" ```"""
# 1. Set generation parameters if not already defined # 1. Set generation parameters if not already defined
...@@ -1138,14 +1152,20 @@ class GenerationMixin: ...@@ -1138,14 +1152,20 @@ class GenerationMixin:
) )
# 6. determine generation mode # 6. determine generation mode
is_constraint_gen_mode = constraints is not None is_constraint_gen_mode = constraints is not None or force_words_ids is not None
is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and constraints is None is_greedy_gen_mode = (
is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and constraints is None (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode
is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and constraints is None )
is_sample_gen_mode = (
(num_beams == 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode
)
is_beam_gen_mode = (
(num_beams > 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode
)
is_beam_sample_gen_mode = ( is_beam_sample_gen_mode = (
(num_beams > 1) and (num_beam_groups == 1) and do_sample is True and constraints is None (num_beams > 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode
) )
is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and constraints is None is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode
if num_beam_groups > num_beams: if num_beam_groups > num_beams:
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
...@@ -1356,9 +1376,46 @@ class GenerationMixin: ...@@ -1356,9 +1376,46 @@ class GenerationMixin:
if num_beam_groups is not None and num_beam_groups > 1: if num_beam_groups is not None and num_beam_groups > 1:
raise ValueError("`num_beam_groups` not supported yet for constrained generation.") raise ValueError("`num_beam_groups` not supported yet for constrained generation.")
final_constraints = []
if constraints is not None:
final_constraints = constraints
if force_words_ids is not None:
def typeerror():
raise ValueError(
"`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
f"of positive integers, but is {force_words_ids}."
)
if not isinstance(force_words_ids, list) or len(force_words_ids) == 0:
typeerror()
for word_ids in force_words_ids:
if isinstance(word_ids[0], list):
if not isinstance(word_ids, list) or len(word_ids) == 0:
typeerror()
if any(not isinstance(token_ids, list) for token_ids in word_ids):
typeerror()
if any(
any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
for token_ids in word_ids
):
typeerror()
constraint = DisjunctiveConstraint(word_ids)
else:
if not isinstance(word_ids, list) or len(word_ids) == 0:
typeerror()
if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
typeerror()
constraint = PhrasalConstraint(word_ids)
final_constraints.append(constraint)
# 10. prepare beam search scorer # 10. prepare beam search scorer
constrained_beam_scorer = ConstrainedBeamSearchScorer( constrained_beam_scorer = ConstrainedBeamSearchScorer(
constraints=constraints, constraints=final_constraints,
batch_size=batch_size, batch_size=batch_size,
num_beams=num_beams, num_beams=num_beams,
device=self.device, device=self.device,
......
...@@ -94,6 +94,13 @@ class ConstraintListState(metaclass=DummyObject): ...@@ -94,6 +94,13 @@ class ConstraintListState(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class DisjunctiveConstraint(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PhrasalConstraint(metaclass=DummyObject): class PhrasalConstraint(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
# coding=utf-8
# Copyright 2020 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a clone of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch
if is_torch_available():
import torch
from transformers.generation_beam_constraints import DisjunctiveConstraint
@require_torch
class ConstraintTest(unittest.TestCase):
def test_input_types(self):
# For consistency across different places the DisjunctiveConstraint is called,
# dc.token_ids is a list of integers. It is also initialized only by integers.
cset = [[1, 2, 4], [1, 2, 3, 4]]
dc = DisjunctiveConstraint(cset)
self.assertTrue(isinstance(dc.token_ids, list))
with self.assertRaises(ValueError):
DisjunctiveConstraint(torch.LongTensor([[1, 2, 4], [1, 2, 3]]))
with self.assertRaises(ValueError):
DisjunctiveConstraint([torch.LongTensor([1, 2, 4]), torch.LongTensor([1, 2, 3, 4, 5])])
def test_check_illegal_input(self):
# We can't have constraints that are complete subsets of another. This leads to a preverse
# interpretation of "constraint fulfillment": does generating [1,2,3] fulfill the constraint?
# It would mean that it generated [1,2] which fulfills it, but it's in the middle of potentially
# fulfilling [1,2,3,4]. If we believe that [1,2,3] does fulfill the constraint, then the algorithm
# will necessarily never reach [1,2,3,4], giving users a false sense of control (better to just not allow it).
cset = [[1, 2], [1, 2, 3, 4]]
with self.assertRaises(ValueError):
DisjunctiveConstraint(cset) # fails here
def test_example_progression(self):
cset = [[1, 2, 3], [1, 2, 4]]
dc = DisjunctiveConstraint(cset)
stepped, completed, reset = dc.update(1)
desired = stepped is True and completed is False and reset is False
self.assertTrue(desired)
self.assertTrue(not dc.completed)
self.assertTrue(dc.current_seq == [1])
stepped, completed, reset = dc.update(2)
desired = stepped is True and completed is False and reset is False
self.assertTrue(desired)
self.assertTrue(not dc.completed)
self.assertTrue(dc.current_seq == [1, 2])
stepped, completed, reset = dc.update(3)
desired = stepped is True and completed is True and reset is False
self.assertTrue(desired)
self.assertTrue(dc.completed) # Completed!
self.assertTrue(dc.current_seq == [1, 2, 3])
def test_example_progression_unequal_three_mid_and_reset(self):
cset = [[1, 2, 3], [1, 2, 4, 5], [1, 2, 5]]
dc = DisjunctiveConstraint(cset)
stepped, completed, reset = dc.update(1)
self.assertTrue(not dc.completed)
self.assertTrue(dc.current_seq == [1])
stepped, completed, reset = dc.update(2)
self.assertTrue(not dc.completed)
self.assertTrue(dc.current_seq == [1, 2])
stepped, completed, reset = dc.update(4)
self.assertTrue(not dc.completed)
self.assertTrue(dc.current_seq == [1, 2, 4])
stepped, completed, reset = dc.update(5)
self.assertTrue(dc.completed) # Completed!
self.assertTrue(dc.current_seq == [1, 2, 4, 5])
dc.reset()
stepped, completed, reset = dc.update(1)
self.assertTrue(not dc.completed)
self.assertTrue(dc.remaining() == 3)
self.assertTrue(dc.current_seq == [1])
stepped, completed, reset = dc.update(2)
self.assertTrue(not dc.completed)
self.assertTrue(dc.remaining() == 2)
self.assertTrue(dc.current_seq == [1, 2])
stepped, completed, reset = dc.update(5)
self.assertTrue(dc.completed) # Completed!
self.assertTrue(dc.remaining() == 0)
self.assertTrue(dc.current_seq == [1, 2, 5])
...@@ -25,7 +25,7 @@ from ..test_modeling_common import floats_tensor, ids_tensor ...@@ -25,7 +25,7 @@ from ..test_modeling_common import floats_tensor, ids_tensor
if is_torch_available(): if is_torch_available():
import torch import torch
from transformers.generation_beam_constraints import PhrasalConstraint from transformers.generation_beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer, ConstrainedBeamSearchScorer from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer, ConstrainedBeamSearchScorer
...@@ -260,10 +260,10 @@ class ConstrainedBeamSearchTester: ...@@ -260,10 +260,10 @@ class ConstrainedBeamSearchTester:
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
if constraints is None: if constraints is None:
force_tokens = torch.randint(10, 50, (1, 2)).type(torch.LongTensor)[0] force_tokens = torch.randint(10, 50, (1, 2))[0].tolist()
constraints = [ disjunctive_tokens = torch.randint(10, 50, (2, 2)).tolist()
PhrasalConstraint(force_tokens),
] constraints = [PhrasalConstraint(force_tokens), DisjunctiveConstraint(disjunctive_tokens)]
self.constraints = constraints self.constraints = constraints
# cannot be randomely generated # cannot be randomely generated
self.eos_token_id = vocab_size + 1 self.eos_token_id = vocab_size + 1
...@@ -331,7 +331,13 @@ class ConstrainedBeamSearchTester: ...@@ -331,7 +331,13 @@ class ConstrainedBeamSearchTester:
): ):
# check too many eos tokens # check too many eos tokens
constrained_beam_scorer = self.prepare_constrained_beam_scorer() constrained_beam_scorer = self.prepare_constrained_beam_scorer()
fulfilling_sequence = torch.stack([constraint.token_ids for constraint in self.constraints]).flatten() stacked_token_ids = []
for constraint in self.constraints:
token_ids = constraint.token_ids
token_ids = token_ids[0] if isinstance(token_ids[0], list) else token_ids
stacked_token_ids = stacked_token_ids + token_ids
fulfilling_sequence = torch.LongTensor(stacked_token_ids)
fulfill_len = fulfilling_sequence.size(0) fulfill_len = fulfilling_sequence.size(0)
input_ids[:, :fulfill_len] = fulfilling_sequence input_ids[:, :fulfill_len] = fulfilling_sequence
...@@ -398,7 +404,14 @@ class ConstrainedBeamSearchTester: ...@@ -398,7 +404,14 @@ class ConstrainedBeamSearchTester:
max_length = self.sequence_length + 1 max_length = self.sequence_length + 1
# for testing finalize, we do want to have fulfilled constraints # for testing finalize, we do want to have fulfilled constraints
fulfilling_sequence = torch.stack([constraint.token_ids for constraint in self.constraints]).flatten() stacked_token_ids = []
for constraint in self.constraints:
token_ids = constraint.token_ids
token_ids = token_ids[0] if isinstance(token_ids[0], list) else token_ids
stacked_token_ids = stacked_token_ids + token_ids
fulfilling_sequence = torch.LongTensor(stacked_token_ids)
fulfill_len = fulfilling_sequence.size(0) fulfill_len = fulfilling_sequence.size(0)
input_ids[:, :fulfill_len] = fulfilling_sequence input_ids[:, :fulfill_len] = fulfilling_sequence
...@@ -451,9 +464,17 @@ class ConstrainedBeamSearchTester: ...@@ -451,9 +464,17 @@ class ConstrainedBeamSearchTester:
self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id) self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id)
# test that the constraint is indeed fulfilled # test that the constraint is indeed fulfilled
for output in sequences: for (output, constraint) in [(s, c) for s in sequences for c in constraints]:
for constraint in constraints: forced_token_ids = constraint.token_ids
forced_token_ids = constraint.token_ids if isinstance(forced_token_ids[0], list):
# disjunctive case
flag = False
for token_ids in forced_token_ids:
if self._check_sequence_inside_sequence(output, token_ids):
flag = True
break
self.parent.assertEqual(flag, True)
else:
self.parent.assertEqual(self._check_sequence_inside_sequence(output, forced_token_ids), True) self.parent.assertEqual(self._check_sequence_inside_sequence(output, forced_token_ids), True)
# now test that if `num_beam_hyps_to_keep` is 3 => all beams are returned # now test that if `num_beam_hyps_to_keep` is 3 => all beams are returned
...@@ -479,18 +500,23 @@ class ConstrainedBeamSearchTester: ...@@ -479,18 +500,23 @@ class ConstrainedBeamSearchTester:
self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size]) self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])
def _check_sequence_inside_sequence(self, tensor_1, tensor_2): def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
# set to same device. we don't care what device. # set to same device. we don't care what device.
tensor_1, tensor_2 = tensor_1.cpu(), tensor_2.cpu()
in_order = tensor_1.size(0) <= tensor_2.size(0) if not isinstance(tensor_1, list):
tensor_1 = tensor_1.cpu().tolist()
if not isinstance(tensor_2, list):
tensor_2 = tensor_2.cpu().tolist()
in_order = len(tensor_1) <= len(tensor_2)
longer = tensor_2 if in_order else tensor_1 longer = tensor_2 if in_order else tensor_1
shorter = tensor_1 if in_order else tensor_2 shorter = tensor_1 if in_order else tensor_2
flag = False flag = False
chunk_size = shorter.size(0) chunk_size = len(shorter)
for chunk_idx in range(longer.size(0) - chunk_size + 1): for chunk_idx in range(len(longer) - chunk_size + 1):
subseq = longer[chunk_idx : chunk_idx + chunk_size] subseq = longer[chunk_idx : chunk_idx + chunk_size]
if torch.equal(subseq, shorter): if subseq == shorter:
flag = True flag = True
break break
......
...@@ -39,7 +39,7 @@ if is_torch_available(): ...@@ -39,7 +39,7 @@ if is_torch_available():
VisionEncoderDecoderModel, VisionEncoderDecoderModel,
top_k_top_p_filtering, top_k_top_p_filtering,
) )
from transformers.generation_beam_constraints import PhrasalConstraint from transformers.generation_beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from transformers.generation_beam_search import BeamSearchScorer, ConstrainedBeamSearchScorer from transformers.generation_beam_search import BeamSearchScorer, ConstrainedBeamSearchScorer
from transformers.generation_logits_process import ( from transformers.generation_logits_process import (
ForcedBOSTokenLogitsProcessor, ForcedBOSTokenLogitsProcessor,
...@@ -1202,7 +1202,7 @@ class GenerationTesterMixin: ...@@ -1202,7 +1202,7 @@ class GenerationTesterMixin:
min_id = 3 min_id = 3
max_id = 100 max_id = 100
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0] force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [ constraints = [
PhrasalConstraint(force_tokens), PhrasalConstraint(force_tokens),
] ]
...@@ -1227,7 +1227,7 @@ class GenerationTesterMixin: ...@@ -1227,7 +1227,7 @@ class GenerationTesterMixin:
# check `generate()` and `constrained_beam_search()` are equal for `num_return_sequences` # check `generate()` and `constrained_beam_search()` are equal for `num_return_sequences`
# Sample constraints # Sample constraints
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0] force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [ constraints = [
PhrasalConstraint(force_tokens), PhrasalConstraint(force_tokens),
] ]
...@@ -1288,7 +1288,7 @@ class GenerationTesterMixin: ...@@ -1288,7 +1288,7 @@ class GenerationTesterMixin:
# otherwise this throws an error for Speech2TextModel since its inputs are floating points # otherwise this throws an error for Speech2TextModel since its inputs are floating points
min_id = 3 min_id = 3
max_id = 100 max_id = 100
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0] force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [ constraints = [
PhrasalConstraint(force_tokens), PhrasalConstraint(force_tokens),
] ]
...@@ -1499,18 +1499,23 @@ class GenerationTesterMixin: ...@@ -1499,18 +1499,23 @@ class GenerationTesterMixin:
) )
def _check_sequence_inside_sequence(self, tensor_1, tensor_2): def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
# set to same device. we don't care what device. # set to same device. we don't care what device.
tensor_1, tensor_2 = tensor_1.cpu(), tensor_2.cpu()
in_order = tensor_1.size(0) <= tensor_2.size(0) if not isinstance(tensor_1, list):
tensor_1 = tensor_1.cpu().tolist()
if not isinstance(tensor_2, list):
tensor_2 = tensor_2.cpu().tolist()
in_order = len(tensor_1) <= len(tensor_2)
longer = tensor_2 if in_order else tensor_1 longer = tensor_2 if in_order else tensor_1
shorter = tensor_1 if in_order else tensor_2 shorter = tensor_1 if in_order else tensor_2
flag = False flag = False
chunk_size = shorter.size(0) chunk_size = len(shorter)
for chunk_idx in range(longer.size(0) - chunk_size + 1): for chunk_idx in range(len(longer) - chunk_size + 1):
subseq = longer[chunk_idx : chunk_idx + chunk_size] subseq = longer[chunk_idx : chunk_idx + chunk_size]
if torch.equal(subseq, shorter): if subseq == shorter:
flag = True flag = True
break break
...@@ -2315,8 +2320,8 @@ class GenerationIntegrationTests(unittest.TestCase): ...@@ -2315,8 +2320,8 @@ class GenerationIntegrationTests(unittest.TestCase):
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device) model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2") tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
force_tokens = tokenizer.encode(" scared", return_tensors="pt").to(torch_device)[0] force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
force_tokens_2 = tokenizer.encode(" big weapons", return_tensors="pt").to(torch_device)[0] force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids
constraints = [ constraints = [
PhrasalConstraint(force_tokens), PhrasalConstraint(force_tokens),
...@@ -2346,6 +2351,105 @@ class GenerationIntegrationTests(unittest.TestCase): ...@@ -2346,6 +2351,105 @@ class GenerationIntegrationTests(unittest.TestCase):
], ],
) )
@slow
def test_constrained_beam_search_mixed(self):
    """Constrained generation mixing a `PhrasalConstraint` (an exact phrase that must
    appear in the output) with a `DisjunctiveConstraint` (at least one of several
    phrases must appear), checked against fixed expected GPT-2 outputs."""
    # Fix: the original used the local path "../gpt2", which only resolves on the
    # author's machine; use the hub model id so the test runs anywhere.
    model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    # `add_prefix_space=True` so the phrases tokenize as mid-sentence words.
    force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
    flexible_phrases = tokenizer(
        ["scream", "screams", "screaming", "screamed"], add_prefix_space=True, add_special_tokens=False
    ).input_ids

    constraints = [
        PhrasalConstraint(force_phrase),
        DisjunctiveConstraint(flexible_phrases),
    ]

    starting_text = ["The soldiers", "The child"]

    input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device)

    outputs = model.generate(
        input_ids,
        constraints=constraints,
        num_beams=10,
        num_return_sequences=1,
        no_repeat_ngram_size=1,
        remove_invalid_values=True,
    )

    generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # Both expected outputs contain "scared" and one of the "scream*" variants.
    self.assertListEqual(
        generated_text,
        [
            "The soldiers, who were all scared and screaming at each other as they tried to get out of the",
            "The child was taken to a local hospital where she screamed and scared for her life, police said.",
        ],
    )
@slow
def test_constrained_beam_search_mixed_mixin(self):
    """Same scenario as `test_constrained_beam_search_mixed`, but driven through the
    `force_words_ids` argument of `generate()` instead of explicit constraint objects:
    a flat list of ids means a phrasal constraint, a nested list means disjunctive."""
    # Fix: the original used the local path "../gpt2", which only resolves on the
    # author's machine; use the hub model id so the test runs anywhere.
    model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    force_word = "scared"
    force_flexible = ["scream", "screams", "screaming", "screamed"]

    force_words_ids = [
        tokenizer([force_word], add_prefix_space=True, add_special_tokens=False).input_ids,
        tokenizer(force_flexible, add_prefix_space=True, add_special_tokens=False).input_ids,
    ]

    starting_text = ["The soldiers", "The child"]

    input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device)

    outputs = model.generate(
        input_ids,
        force_words_ids=force_words_ids,
        num_beams=10,
        num_return_sequences=1,
        no_repeat_ngram_size=1,
        remove_invalid_values=True,
    )

    generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # Expected outputs match the explicit-constraint variant of this test exactly.
    self.assertListEqual(
        generated_text,
        [
            "The soldiers, who were all scared and screaming at each other as they tried to get out of the",
            "The child was taken to a local hospital where she screamed and scared for her life, police said.",
        ],
    )
@slow
def test_constrained_beam_search_example_translation_mixin(self):
    """Force the word "sind" into a T5 English-to-German translation via the
    `force_words_ids` mixin argument and check the exact decoded output."""
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

    source_text = "translate English to German: How old are you?"
    encoder_inputs = tokenizer(source_text, return_tensors="pt").input_ids

    # A list of token-id lists => each inner list is a phrasal constraint.
    constraint_ids = tokenizer(["sind"], add_special_tokens=False).input_ids

    generated = model.generate(
        encoder_inputs,
        force_words_ids=constraint_ids,
        num_beams=10,
        num_return_sequences=1,
        no_repeat_ngram_size=1,
        remove_invalid_values=True,
    )

    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    self.assertListEqual(decoded, ["Wie alter sind Sie?"])
@slow @slow
def test_constrained_beam_search_example_integration(self): def test_constrained_beam_search_example_integration(self):
tokenizer = AutoTokenizer.from_pretrained("t5-base") tokenizer = AutoTokenizer.from_pretrained("t5-base")
...@@ -2389,3 +2493,43 @@ class GenerationIntegrationTests(unittest.TestCase): ...@@ -2389,3 +2493,43 @@ class GenerationIntegrationTests(unittest.TestCase):
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(outputs, ["Wie alter sind Sie?"]) self.assertListEqual(outputs, ["Wie alter sind Sie?"])
def test_constrained_beam_search_mixin_type_checks(self):
    """`generate()` must raise `ValueError` for every malformed `force_words_ids`
    shape: tensors instead of plain int lists, an empty list, and negative ids at
    any nesting depth."""
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

    encoder_input_str = "translate English to German: How old are you?"
    input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

    # Shared generation options for the tensor-shaped rejection cases.
    generation_kwargs = {
        "num_beams": 10,
        "num_return_sequences": 1,
        "no_repeat_ngram_size": 1,
        "remove_invalid_values": True,
    }

    # A torch tensor is not an accepted constraint spec — only lists of ints are.
    with self.assertRaises(ValueError):
        tensor_ids = tokenizer(["sind"], return_tensors="pt").input_ids
        model.generate(input_ids, force_words_ids=tensor_ids, **generation_kwargs)

    # Nor is a list wrapping a tensor.
    with self.assertRaises(ValueError):
        wrapped_tensor_ids = [tokenizer(["sind"], return_tensors="pt").input_ids]
        model.generate(input_ids, force_words_ids=wrapped_tensor_ids, **generation_kwargs)

    # Empty constraint lists and negative token ids are rejected at every level.
    with self.assertRaises(ValueError):
        model.generate(input_ids, force_words_ids=[])

    with self.assertRaises(ValueError):
        model.generate(input_ids, force_words_ids=[[-1]])

    with self.assertRaises(ValueError):
        model.generate(input_ids, force_words_ids=[[[-1]]])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment