Unverified Commit 0d84901c authored by Matt, committed by GitHub

Terminator strings for generate() (#28932)



* stash commit (will discard all of this)

* stash commit

* First commit - needs a lot of testing!

* Add a test

* Fix imports and make the tests actually test something

* Tests pass!

* Rearrange test

* Add comments (but it's still a bit confusing)

* Stop storing the tokenizer

* Comment fixup

* Fix for input_ids with a single sequence

* Update tests to test single sequences

* make fixup

* Fix incorrect use of isin()

* Expand tests to catch more cases

* Expand tests to catch more cases

* make fixup

* Fix length calculation and update tests

* Handle Ġ as a space replacement too

* Update src/transformers/generation/stopping_criteria.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Add optimizations from Joao's suggestion

* Remove TODO

* Update src/transformers/generation/stopping_criteria.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/generation/test_stopping_criteria.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* make fixup

* Rename some variables and remove some debugging clauses for clarity

* Add tests for the sub-methods

* Clarify one test slightly

* Add stop_strings to GenerationConfig

* generate() supports stop_string arg, asks for tokenizer if not provided

* make fixup

* Cleanup code and rename variables for clarity

* Update tokenizer error

* Update tokenizer passing, handle generation on GPU

* Slightly more explanation cleanup

* More comment cleanup

* Factor out the token cleanup so it's more obvious what we're doing, and we can change it later

* Careful with that cleanup!

* Cleanup + optimizations to _get_matching_positions

* More minor performance tweaks

* Implement caching and eliminate some expensive ops (startup time: 200ms -> 9ms)

* Remove the pin_memory call

* Parallelize across all stop strings!

* Quick fix for tensor devices

* Update embeddings test for the new format

* Fix test imports

* Manual patching for BERT-like tokenizers

* Return a bool vector instead of a single True/False

* Better comment

* Better comment

* Add tests from @zucchini-nlp

* Amy's list creation nit

* tok_list -> token_list

* Push a big expanded docstring (should we put it somewhere else?)

* Expand docstrings

* Docstring fixups

* Rebase

* make fixup

* Make a properly general method for figuring out token strings

* Fix naming throughout the functions

* Move cache, refactor, fix tests

* Add comment

* Remove finished TODO

* Remove finished TODO

* make fixup

* Update src/transformers/generation/stopping_criteria.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update and shorten docstring

* Update tests to be shorter/clearer and test specific cases

---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 0e9d44d7
......@@ -86,6 +86,7 @@ else:
"StoppingCriteria",
"StoppingCriteriaList",
"validate_stopping_criteria",
"StopStringCriteria",
]
_import_structure["utils"] = [
"GenerationMixin",
......@@ -224,6 +225,7 @@ if TYPE_CHECKING:
MaxTimeCriteria,
StoppingCriteria,
StoppingCriteriaList,
StopStringCriteria,
validate_stopping_criteria,
)
from .utils import (
......
......@@ -115,6 +115,8 @@ class GenerationConfig(PushToHubMixin):
max_time(`float`, *optional*):
The maximum amount of time you allow the computation to run for in seconds. Generation will still finish
the current pass after the allocated time has passed.
stop_strings(`str or List[str]`, *optional*):
A string or a list of strings that should terminate generation if the model outputs them. Generating with
stop strings requires the model's tokenizer, which must be passed to `generate()` via its `tokenizer` argument.
> Parameters that control the generation strategy used
......@@ -306,6 +308,7 @@ class GenerationConfig(PushToHubMixin):
self.min_new_tokens = kwargs.pop("min_new_tokens", None)
self.early_stopping = kwargs.pop("early_stopping", False)
self.max_time = kwargs.pop("max_time", None)
self.stop_strings = kwargs.pop("stop_strings", None)
# Parameters that control the generation strategy used
self.do_sample = kwargs.pop("do_sample", False)
......
import time
import warnings
from abc import ABC
from collections import OrderedDict
from copy import deepcopy
from typing import List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from torch.nn import functional as F
from ..tokenization_utils_base import PreTrainedTokenizerBase
from ..utils import add_start_docstrings, logging
logger = logging.get_logger(__name__)
# We maintain a module-level cache of the embedding vectors for the stop string criterion
# because they are slow to compute
STOP_STRING_EMBEDDING_CACHE = OrderedDict()
STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
......@@ -129,6 +136,334 @@ class MaxTimeCriteria(StoppingCriteria):
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)
class StopStringCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever specific string sequences are generated. It preprocesses
the strings together with the tokenizer vocab to find positions where tokens can validly complete the stop strings.
Generation is stopped as soon as a token is generated that completes any of the stop strings.
We want to catch any instance in which the stop string would be present in the decoded output, which means
we must also catch cases with "overhangs" off one or both ends. To make this more concrete, for the stop string
"stop", any of the following token sequences would trigger the match:
- ["st", "op"]
- ["stop"]
- ["st", "opera"]
- ["sto", "pper"]
- ["las", "topper"]
- ["s", "to", "pped"]
Note that a match will only be triggered if the stop string is at the end of the generated sequence. In other
words, these sequences will not trigger a match:
- ["stop", "at"]
- ["st", "op", "at"]
- ["st", "opera", "tion"]
The reason these are not a match is that the stop string does not overlap with the final token. If you can remove
one or more tokens from the end of the sequence without destroying the stop string, then this criterion will not
match that stop string. This is by design; because this check is run after each token is generated, we can't miss a
valid stop string if one is generated, but we don't want to halt generation just because the stop string exists
somewhere in the past input_ids.
How is the match actually performed, though? We do it in quite a confusing way, because we want the entire match
process to be compilable with Torch or XLA, which means we cannot use standard string methods. However, it is possible,
with some work, to do string matching with pure tensor operations. We'll begin by describing the algorithm we use
with standard string operations, and then at the end we'll explain how this is converted to pure tensor operations.
The key to the algorithm is an observation: Because the stop string must overlap with the end of the token sequence, we can start at
the end of the sequence and work backwards. Specifically, we check that there is an overlap between the start of
the final token and the end of the stop_string, or to put it another way, stop_string[-i:] == token[:i] for
some i > 0. If you look at the positive examples above, you'll see the last token in all of them fulfills this
property:
- ["st", "op"] (overlap is "op", overlap length == 2)
- ["stop"] (overlap is "stop", overlap length == 4)
- ["st", "opera"] (overlap is "op", overlap length == 2)
- ["sto", "pper"] (overlap is "p", overlap length == 1)
- ["las", "topper"] (overlap is "top", overlap length == 3)
- ["s", "to", "pped"] (overlap is "p", overlap length == 1)
It's impossible to construct a matching sequence that does not have this property (feel free to verify this
yourself). However, although this overlap between the start of the final token and the end of the stop string is
necessary for a match, it is not sufficient. We also need to check that the rest of the token sequence is
consistent with the stop string.
How do we do that? Let's use ["s", "to", "pped"] as an example. We know that the final token, "pped", has an
overlap of 1 with the stop string, "stop". We then go back to the previous token, "to". Since we have already
matched 1 character from the stop string, the remainder to check is "sto". We check that the next token "to"
matches the end of the remainder, which it does. We have now matched 3 characters from the stop string, and the
remainder to match is "s". We go back to the previous token again, which is also "s". This is a match, and so
we have matched the entire stop string.
How does it work when the tokens run off the start of the stop string, though? Let's consider the example of
["las", "topper"]. The final token, "topper", has an overlap of 3 with the stop string, "stop". Therefore,
the remaining stop string to match is "s". We go back to the previous token, "las". Because the remainder to
match is just "s", with length 1, we consider only the final 1 character from the token, which is "s". This
matches the stop string, and so the entire string is matched.
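As a rough pure-Python sketch (illustrative only; for brevity it tries just the longest end overlap, whereas the
real implementation tracks every possible overlap in parallel), the backward scan described above looks like:
```python
def naive_stop_string_match(tokens, stop_string):
    # Number of characters of the stop string matched so far, counting from its end
    pos = 0
    for token in reversed(tokens):
        if pos == 0:
            # The final token must overlap the end of the stop string
            overlap = max(
                (i for i in range(1, min(len(token), len(stop_string)) + 1)
                 if stop_string.endswith(token[:i])),
                default=0,
            )
            if overlap == 0:
                return False
            pos = overlap
        else:
            # Earlier tokens must match the end of the remaining stop string,
            # possibly running off its start
            remainder = stop_string[:-pos]
            if not remainder.endswith(token[-len(remainder):]):
                return False
            pos += len(token)
        if pos >= len(stop_string):
            return True
    return False
```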
How do we compute these matches with tensor operations, though? Simply: we efficiently precompute the necessary
information for all tokens! For every token, we compute:
- Its overlap with the end of the stop string, if any
- The positions inside the stop string where the token matches, including matches that run off the start.
- The total length of the token
For example, for the token "pped", we would compute an end overlap of 1, no internal matching positions,
and a length of 4. For the token "to", we would compute no end overlap, a single internal matching position
of 1 (counting from the end), and a length of 2. For the token "s", we would compute no end overlap,
a single internal matching position of 3 (again counting from the end) and a length of 1.
As long as we have this information, we can execute the algorithm above without any string comparison
operations. We simply perform the following steps:
- Check if the final token has an end-overlap with the stop string
- Continue backwards, keeping track of how much of the stop string we've matched so far
- At each point, check if the next token has the current position as one of its valid positions
- Continue until either a match fails, or we completely match the whole stop string
Again, consider ["s", "to", "pped"] as an example. "pped" has an end overlap of 1, so we can begin a match.
We have matched 1 character so far, so we check that the next token "to", has 1 as a valid position (again,
counting from the end). It does, so we add the length of "to" to our position tracker. We have now matched
3 characters, so we check that the next token "s" has 3 as a valid position. It does, so we add its length
to the position tracker. The position tracker is now 4, which is the length of the stop string. We have matched the
entire stop string.
In the second case, ["las", "topper"], "topper" has an end overlap of 3, so we can begin a match. We have
matched 3 characters so far, so we check that the next token "las" has 3 as a valid position. It does, because we
allow tokens to match positions that run off the start of the stop string. We add its length to the position
tracker. The position tracker is now 6, which is greater than the length of the stop string! Don't panic, though -
this also counts as a match of the stop string. We have matched the entire stop string.
Args:
tokenizer (`PreTrainedTokenizer`):
The model's associated tokenizer (necessary to extract vocab and tokenize the termination sequences)
stop_strings (`Union[str, List[str]]`):
A list of strings that should end generation. If a string is passed, it will be treated like a
list with a single element.
Examples:
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
>>> model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
>>> inputs = tokenizer("The biggest states in the USA by land area:", return_tensors="pt")
>>> gen_out = model.generate(**inputs)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
The biggest states in the USA by land area:
- Alaska
- Texas
- California
>>> # Passing one or more stop strings will halt generation after those strings are emitted
>>> # Note that generating with stop strings requires you to pass the tokenizer too
>>> gen_out = model.generate(**inputs, stop_strings=["Texas"], tokenizer=tokenizer)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
The biggest states in the USA by land area:
- Alaska
- Texas
```
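The criterion can also be constructed directly and passed to `generate()` through `stopping_criteria` (a sketch
equivalent to the `stop_strings` shortcut above, reusing the model, tokenizer and inputs from the previous example):
```python
>>> from transformers import StoppingCriteriaList, StopStringCriteria
>>> criteria = StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings=["Texas"])])
>>> gen_out = model.generate(**inputs, stopping_criteria=criteria)
```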
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, List[str]]):
if isinstance(stop_strings, str):
stop_strings = [stop_strings]
self.stop_strings: Tuple[str, ...] = tuple(stop_strings)
vocab = tokenizer.get_vocab()
token_list, token_indices = tuple(vocab.keys()), tuple(vocab.values())
self.embedding_vec, self.max_valid_positions, self.max_valid_end_lens = self.clean_and_embed_tokens_with_cache(
token_list, token_indices, self.stop_strings, tokenizer
)
self.maximum_token_len = max([len(stop_string) for stop_string in self.stop_strings])
self.num_stop_strings = len(self.stop_strings)
self.target_lens = torch.tensor([len(stop_string) for stop_string in stop_strings], dtype=torch.int32)
def clean_and_embed_tokens_with_cache(self, token_list, token_indices, stop_strings, tokenizer):
# We don't use the tokenizer in the cache key, because I don't trust it to have well-behaved equality
if (token_list, token_indices, stop_strings) in STOP_STRING_EMBEDDING_CACHE:
embedding_vec, max_valid_positions, max_valid_end_lens = STOP_STRING_EMBEDDING_CACHE[
(token_list, token_indices, self.stop_strings)
]
STOP_STRING_EMBEDDING_CACHE.move_to_end((token_list, token_indices, stop_strings))
else:
clean_token_list, clean_token_indices = self.clean_tokenizer_vocab(tokenizer)
embedding_vec, max_valid_positions, max_valid_end_lens = self._stop_string_create_embedding_vec(
clean_token_list, clean_token_indices, stop_strings
)
STOP_STRING_EMBEDDING_CACHE[(token_list, token_indices, stop_strings)] = (
embedding_vec,
max_valid_positions,
max_valid_end_lens,
)
if len(STOP_STRING_EMBEDDING_CACHE) > 8:
STOP_STRING_EMBEDDING_CACHE.popitem(last=False) # Pop from the start, the least recently used item
return embedding_vec, max_valid_positions, max_valid_end_lens
@staticmethod
def clean_tokenizer_vocab(tokenizer, static_prefix="abcdef"):
"""
This method turns a tokenizer vocab into a "clean" vocab where each token represents the actual string
it will yield, without any special prefixes like "##" or "Ġ". This is trickier than it looks - the method
tokenizer.convert_tokens_to_string() does not always return the correct string because of issues with prefix
space addition/removal. To work around this, we add a static prefix to the start of the token, then remove
it (and any prefix that may have been introduced with it) after calling convert_tokens_to_string().
"""
vocab = tokenizer.get_vocab()
clean_token_list = []
clean_token_indices = []
sentence_base = tokenizer(static_prefix, add_special_tokens=False)["input_ids"]
tokens_base = [tokenizer._convert_id_to_token(tok) for tok in sentence_base]
for token, token_idx in vocab.items():
token_string = tokenizer.convert_tokens_to_string(tokens_base + [token])
token_string = token_string[token_string.index(static_prefix) + len(static_prefix) :]
clean_token_list.append(token_string)
clean_token_indices.append(token_idx)
return tuple(clean_token_list), tuple(clean_token_indices)
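# Hedged illustration: for a byte-level BPE tokenizer such as GPT-2's, a raw vocab entry like
# "Ġworld" (where "Ġ" stands in for a leading space) comes back from clean_tokenizer_vocab() as
# " world", and a WordPiece continuation token like "##ing" comes back as "ing", so matching is
# done against the text each token actually contributes to the decoded output.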
@staticmethod
def _stop_string_get_matching_positions(
token_list, token_indices, stop_strings
) -> Tuple[Dict[str, Dict[str, List[int]]], Dict[str, Dict[str, List[int]]]]:
"""This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can
validly appear in the stop strings. For each token, it computes a list of positions in the stop string where the
token appears, as well as a list of the possible "end overlaps" for that token - that is, the number of characters
from the end of the stop string that overlap with the start of the token, which can have more than one value.
The reason for computing these may seem a bit cryptic - please see the docstring for StopStringCriteria for a full
explanation of what these values are for!"""
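# Hedged example mirroring the unit tests: for stop_strings=["stop"] and a vocabulary of
# ["last", "top", "topper", "s", "p"], this returns
#   token_valid_positions == {"stop": {idx("last"): [2], idx("s"): [3]}}
#   token_end_overlaps   == {"stop": {idx("top"): [3], idx("topper"): [3], idx("p"): [1]}}
# where idx() denotes the token's vocabulary index.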
token_valid_positions = {}
token_end_overlaps = {}
for stop_string in stop_strings:
reversed_stop_string = stop_string[::-1]
token_valid_positions[stop_string] = {}
token_end_overlaps[stop_string] = {}
for token, tok_idx in zip(token_list, token_indices):
reversed_token = token[::-1]
matching_positions = []
possible_end_lengths = []
for i in range(1 - len(token), len(stop_string)):
if i < 0:
tok = reversed_token[-i:]
i = 0
else:
tok = reversed_token
stop = reversed_stop_string[i : i + len(tok)]
if tok.startswith(stop):
if i == 0:
possible_end_lengths.append(min(len(tok), len(stop)))
else:
matching_positions.append(i)
if matching_positions:
token_valid_positions[stop_string][tok_idx] = matching_positions
if possible_end_lengths:
token_end_overlaps[stop_string][tok_idx] = possible_end_lengths
return token_valid_positions, token_end_overlaps
@staticmethod
def _stop_string_create_embedding_vec(token_list, token_indices, stop_strings) -> Dict[str, torch.tensor]:
"""This function precomputes everything needed for the run-time checks in StopStringCriteria, and packs
them into an embedding tensor that can be accessed with pure tensor operations. For the specifics of the values
that are precomputed and what they are used for, please refer to the StopStringCriteria docstring!"""
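# Layout of each row of the returned tensor (one row per vocabulary index), as built below:
#   [ valid positions for stop string 0 | ... | valid positions for stop string N-1 |
#     end-overlap lengths for stop string 0 | ... | end-overlap lengths for stop string N-1 |
#     token length ]
# Unused slots are padded with -1.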
token_valid_positions, token_end_overlaps = StopStringCriteria._stop_string_get_matching_positions(
token_list, token_indices, stop_strings
)
max_valid_positions = max(
len(val) for positions in token_valid_positions.values() for val in positions.values()
)
max_valid_end_lens = max(len(val) for positions in token_end_overlaps.values() for val in positions.values())
vec_size = len(stop_strings) * (max_valid_positions + max_valid_end_lens) + 1
gather_vec = np.full((len(token_list), vec_size), dtype=np.int32, fill_value=-1)
for i, stop_string in enumerate(stop_strings):
positions = token_valid_positions[stop_string]
end_lens = token_end_overlaps[stop_string]
# Since this is lots of very small assignments of lists, we build it with numpy rather
# than torch for speed + simplicity, then convert to torch at the end
for token_idx, valid_positions in positions.items():
gather_vec[
token_idx, max_valid_positions * i : max_valid_positions * i + len(valid_positions)
] = valid_positions
for token_idx, possible_end_lens in end_lens.items():
gather_vec[
token_idx,
max_valid_positions * len(stop_strings) + max_valid_end_lens * i : max_valid_positions
* len(stop_strings)
+ max_valid_end_lens * i
+ len(possible_end_lens),
] = possible_end_lens
for token, token_idx in zip(token_list, token_indices):
gather_vec[token_idx, -1] = len(token)
gather_vec = torch.tensor(gather_vec, dtype=torch.int32)
return gather_vec, max_valid_positions, max_valid_end_lens
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.Tensor:
self.embedding_vec = self.embedding_vec.to(input_ids.device)
self.target_lens = self.target_lens.to(input_ids.device)
# The maximum length we need to consider is one token per character of the longest stop string. Note that input_ids can also be
# *shorter* than the global max, and the code below should be ready for that
input_ids = input_ids[:, -self.maximum_token_len :]
# Flip input_ids because we're only matching strings at the end of the generated sequence
flipped_ids = torch.flip(input_ids, (1,))
# Size of the vector of positions a single token can match
max_valid_positions = self.max_valid_positions
# The embedding vec contains the valid positions, end_lengths and total lengths for each token
embedded = F.embedding(flipped_ids, self.embedding_vec)
# Now we split the embedding vector. valid_positions is the positions in the stop string the token can fit
valid_positions = embedded[:, 1:, : max_valid_positions * self.num_stop_strings].unflatten(
-1, (self.num_stop_strings, -1)
)
# end_lengths is the number of characters from the string, counting from the end, that the token
# contains. It can have multiple values if the same token can overlap different end lengths
end_lengths = embedded[:, :1, max_valid_positions * self.num_stop_strings : -1].unflatten(
-1, (self.num_stop_strings, -1)
)
# Lengths is the total length of each token. Unlike the others, it always has a single value
lengths = embedded[:, 1:, None, -1:] # Insert a dummy dimension for stop_strings even though lengths are const
# Concatenate lengths onto each possible end_lengths value
lengths = lengths.expand((-1, -1, end_lengths.shape[-2], end_lengths.shape[-1]))
lengths_with_ends = torch.cat([end_lengths, lengths], dim=1)
# cumsum() to get the number of matched characters in the stop string after each token
cumsum = lengths_with_ends.cumsum(dim=1) # B x maximum_token_len x num_stop_strings x max_valid_end_lens
# The calculation above assumes that all tokens are in valid positions. Now we mask the ones that are not.
# First, a token can begin a match (i.e. overlap the end of the stop string) if it has a positive value in the end_lengths vector
initial_match = end_lengths > 0
# Tokens continue the string if the cumsum() so far is one of the valid positions for that token
# Note that we're actually tracking one cumsum() for each possible end_length
later_match = torch.any(cumsum[:, :-1, :, None] == valid_positions[:, :, :, :, None], axis=-2)
# The match vector is a boolean vector that indicates which positions have valid tokens
match = torch.cat([initial_match, later_match], dim=1)
# Once a single position does not match, all positions following that position are masked
mask = (~match).cumsum(dim=1, dtype=torch.int32)
mask = mask == 0
# The string is matched if we reached a cumsum equal to or greater than the length of the string
# before hitting the mask
string_matches = torch.amax(cumsum * mask, dim=(1, -1)) >= self.target_lens[None, :]
# We return a per-sample vector that is True if any stop string is matched for that sample
return torch.any(string_matches, dim=-1)
class EosTokenCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever the "end-of-sequence" token is generated.
......
......@@ -80,12 +80,14 @@ from .stopping_criteria import (
MaxTimeCriteria,
StoppingCriteria,
StoppingCriteriaList,
StopStringCriteria,
validate_stopping_criteria,
)
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from ..tokenization_utils_base import PreTrainedTokenizerBase
from .streamers import BaseStreamer
logger = logging.get_logger(__name__)
......@@ -885,7 +887,11 @@ class GenerationMixin:
return processors
def _get_stopping_criteria(
self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList]
self,
generation_config: GenerationConfig,
stopping_criteria: Optional[StoppingCriteriaList],
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
**kwargs,
) -> StoppingCriteriaList:
criteria = StoppingCriteriaList()
if generation_config.max_length is not None:
......@@ -898,6 +904,14 @@ class GenerationMixin:
)
if generation_config.max_time is not None:
criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
if generation_config.stop_strings is not None:
if tokenizer is None:
raise ValueError(
"There are one or more stop strings, either in the arguments to `generate` or in the "
"model's generation config, but we could not locate a tokenizer. When generating with "
"stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`."
)
criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer))
if generation_config.eos_token_id is not None:
criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id))
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
......@@ -1380,6 +1394,7 @@ class GenerationMixin:
"""
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
self._validate_model_kwargs(model_kwargs.copy())
......@@ -1389,6 +1404,7 @@ class GenerationMixin:
synced_gpus = True
else:
synced_gpus = False
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
......@@ -1531,7 +1547,7 @@ class GenerationMixin:
# 9. prepare stopping criteria
prepared_stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria
generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
)
# 10. go into different generation modes
if generation_mode == GenerationMode.ASSISTED_GENERATION:
......
......@@ -16,7 +16,7 @@
import time
import unittest
from transformers import is_torch_available
from transformers import AutoTokenizer, is_torch_available
from transformers.testing_utils import require_torch, torch_device
from ..test_modeling_common import ids_tensor
......@@ -31,6 +31,7 @@ if is_torch_available():
MaxNewTokensCriteria,
MaxTimeCriteria,
StoppingCriteriaList,
StopStringCriteria,
validate_stopping_criteria,
)
......@@ -124,3 +125,134 @@ class StoppingCriteriaTestCase(unittest.TestCase):
stopping_criteria = validate_stopping_criteria(StoppingCriteriaList(), 11)
self.assertEqual(len(stopping_criteria), 1)
def test_stop_string_criteria(self):
true_strings = [
"<|im_start|><|im_end|>",
"<|im_start|><|im_end|<|im_end|>",
">><|im_start|>>stop",
"stop",
"e nd",
]
false_strings = [
"<|im_start|><|im_end|",
"<|im_start|><|im_end|<|im_end|",
"<|im_end|><|im_start|>",
"<|im_end|<>stop<|im_end|",
"end",
"en d",
"eNd",
"<|im_end|",
"|im_end|>",
"s",
]
stop_strings = ["<|im_end|>", "stop", "e nd"]
# Use a tokenizer that won't actually have special tokens for these
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
scores = None
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings)
for i in range(len(true_strings)):
self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores))
for i in range(len(false_strings)):
self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores))
# Now try it with a tokenizer where those are actually special tokens
tokenizer = AutoTokenizer.from_pretrained("cognitivecomputations/dolphin-2.5-mixtral-8x7b")
tokenizer.padding_side = "left"
true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings)
for i in range(len(true_strings)):
self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores))
for i in range(len(false_strings)):
self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores))
def test_stop_string_matching_positions(self):
stop_string = "stop"
token_list = ["last", "top", "topper", "s", "p"]
token_indices = list(range(len(token_list)))
all_token_valid_positions, all_token_end_overlaps = StopStringCriteria._stop_string_get_matching_positions(
token_list=token_list, token_indices=token_indices, stop_strings=[stop_string]
)
valid_positions = {
token_list[idx]: positions for idx, positions in all_token_valid_positions[stop_string].items()
}
end_overlaps = {token_list[idx]: overlaps for idx, overlaps in all_token_end_overlaps[stop_string].items()}
self.assertEqual(valid_positions, {"s": [3], "last": [2]})
self.assertEqual(end_overlaps, {"top": [3], "topper": [3], "p": [1]})
def test_stop_string_embedding_vecs(self):
stop_string = "stop"
token_list = ["last", "top", "topper", "s", "p"]
token_indices = list(range(len(token_list)))
embedding_vec, max_valid_positions, max_valid_end_lens = StopStringCriteria._stop_string_create_embedding_vec(
token_list=token_list, token_indices=token_indices, stop_strings=[stop_string]
)
# Positions inside the stop string where the token matches (excluding end overlaps)
valid_positions = embedding_vec[:, 0].tolist()
self.assertEqual(valid_positions, [2, -1, -1, 3, -1])
# Overlap lengths between end of stop string and start of token
end_overlaps = embedding_vec[:, 1].tolist()
self.assertEqual(end_overlaps, [-1, 3, 3, -1, 1])
# Length of each token
token_lengths = embedding_vec[:, 2].tolist()
self.assertEqual(token_lengths, [len(token) for token in token_list])
def test_criterias_per_row(self):
text = "They completed the challenging puzzle, revealing the hidden image at the end"
stop_strings = ["end"]
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False)
scores = None
criteria = StoppingCriteriaList(
[
MaxLengthCriteria(max_length=20),
StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings),
]
)
# trigger stopping when at least one criterion is satisfied, one value per batch
self.assertTrue(criteria(inputs["input_ids"], scores))
# return False when neither is satisfied
self.assertFalse(criteria(inputs["input_ids"][:, :-1], scores))
def test_criterias_per_row_batched(self):
text = [
"They completed the challenging puzzle, revealing the hidden image at the end",
"Today a dragon flew over France",
"The aroma of freshly baked pizza filled the kitchen",
]
stop_strings = ["end"]
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
inputs = tokenizer(text, return_tensors="pt", padding="longest", add_special_tokens=False)
scores = None
criteria = StoppingCriteriaList(
[
MaxLengthCriteria(max_length=20),
StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings),
]
)
# trigger stopping when at least one criterion is satisfied
self.assertListEqual(criteria(inputs["input_ids"], scores).tolist(), [True, False, False])
# False when neither is satisfied
self.assertListEqual(criteria(inputs["input_ids"][:, :-1], scores).tolist(), [False, False, False])
......@@ -2330,6 +2330,43 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertListEqual(outputs, ["Wie alt sind Sie?"])
@slow
def test_per_row_stopping_criteria(self):
text = [
"They completed the challenging puzzle, revealing the hidden",
"Today a dragon flew over France",
"The aroma of freshly baked pizza filled the kitchen",
]
stop_strings = ["secrets"]
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
tokenizer.padding_side = "left"
tokenizer.pad_token_id = tokenizer.eos_token_id
input_ids = tokenizer(text, return_tensors="pt", padding="longest", add_special_tokens=False).input_ids.to(
torch_device
)
# normal generation with one stopping criterion
out = model.generate(input_ids, max_length=15)
out_text = tokenizer.batch_decode(out)
expected_out = [
"They completed the challenging puzzle, revealing the hidden secrets of the world.\n",
"<|endoftext|><|endoftext|><|endoftext|>Today a dragon flew over France and the French government was forced",
"The aroma of freshly baked pizza filled the kitchen with a sense of freshness",
]
self.assertListEqual(out_text, expected_out)
# generation should stop at "secrets" for the first sequence only, padding the rest of that row with eos tokens
out = model.generate(input_ids, max_length=15, stop_strings=stop_strings, tokenizer=tokenizer)
out_text = tokenizer.batch_decode(out)
expected_out = [
"They completed the challenging puzzle, revealing the hidden secrets<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>",
"<|endoftext|><|endoftext|><|endoftext|>Today a dragon flew over France and the French government was forced",
"The aroma of freshly baked pizza filled the kitchen with a sense of freshness",
]
self.assertListEqual(out_text, expected_out)
def test_constrained_beam_search_mixin_type_checks(self):
# PT-only test: TF doesn't have constrained beam search
tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
......