"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "62cbd8423b3b8ab173fba9cf8c7346eca5f76d32"
Unverified Commit 0d84901c authored by Matt, committed by GitHub

Terminator strings for generate() (#28932)



* stash commit (will discard all of this)

* stash commit

* First commit - needs a lot of testing!

* Add a test

* Fix imports and make the tests actually test something

* Tests pass!

* Rearrange test

* Add comments (but it's still a bit confusing)

* Stop storing the tokenizer

* Comment fixup

* Fix for input_ids with a single sequence

* Update tests to test single sequences

* make fixup

* Fix incorrect use of isin()

* Expand tests to catch more cases

* Expand tests to catch more cases

* make fixup

* Fix length calculation and update tests

* Handle Ġ as a space replacement too

* Update src/transformers/generation/stopping_criteria.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Add optimizations from Joao's suggestion

* Remove TODO

* Update src/transformers/generation/stopping_criteria.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/generation/test_stopping_criteria.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* make fixup

* Rename some variables and remove some debugging clauses for clarity

* Add tests for the sub-methods

* Clarify one test slightly

* Add stop_strings to GenerationConfig

* generate() supports stop_string arg, asks for tokenizer if not provided

* make fixup

* Cleanup code and rename variables for clarity

* Update tokenizer error

* Update tokenizer passing, handle generation on GPU

* Slightly more explanation cleanup

* More comment cleanup

* Factor out the token cleanup so it's more obvious what we're doing, and we can change it later

* Careful with that cleanup!

* Cleanup + optimizations to _get_matching_positions

* More minor performance tweaks

* Implement caching and eliminate some expensive ops (startup time: 200ms -> 9ms)

* Remove the pin_memory call

* Parallelize across all stop strings!

* Quick fix for tensor devices

* Update embeddings test for the new format

* Fix test imports

* Manual patching for BERT-like tokenizers

* Return a bool vector instead of a single True/False

* Better comment

* Better comment

* Add tests from @zucchini-nlp

* Amy's list creation nit

* tok_list -> token_list

* Push a big expanded docstring (should we put it somewhere else?)

* Expand docstrings

* Docstring fixups

* Rebase

* make fixup

* Make a properly general method for figuring out token strings

* Fix naming throughout the functions

* Move cache, refactor, fix tests

* Add comment

* Remove finished TODO

* Remove finished TODO

* make fixup

* Update src/transformers/generation/stopping_criteria.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update and shorten docstring

* Update tests to be shorter/clearer and test specific cases

---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 0e9d44d7
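
For orientation before the diff: a minimal usage sketch of the feature this commit adds (illustrative only, not part of the change set; the model and stop string are arbitrary). The tokenizer has to be passed to `generate()` so the stop-string criterion can map strings back onto token sequences:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

inputs = tokenizer("They completed the challenging puzzle, revealing the hidden", return_tensors="pt")

# Generation terminates once a stop string appears in the decoded output;
# without `tokenizer=...`, using stop_strings raises a ValueError (see the diff below).
out = model.generate(**inputs, max_new_tokens=50, stop_strings=["secrets"], tokenizer=tokenizer)
print(tokenizer.batch_decode(out))

The diff below wires this through GenerationConfig, generate(), and a new StopStringCriteria stopping criterion.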
@@ -86,6 +86,7 @@ else:
        "StoppingCriteria",
        "StoppingCriteriaList",
        "validate_stopping_criteria",
+        "StopStringCriteria",
    ]
    _import_structure["utils"] = [
        "GenerationMixin",
@@ -224,6 +225,7 @@ if TYPE_CHECKING:
        MaxTimeCriteria,
        StoppingCriteria,
        StoppingCriteriaList,
+        StopStringCriteria,
        validate_stopping_criteria,
    )
    from .utils import (
...
@@ -115,6 +115,8 @@ class GenerationConfig(PushToHubMixin):
        max_time(`float`, *optional*):
            The maximum amount of time you allow the computation to run for in seconds. generation will still finish
            the current pass after allocated time has been passed.
+        stop_strings(`str or List[str]`, *optional*):
+            A string or a list of strings that should terminate generation if the model outputs them.

        > Parameters that control the generation strategy used
@@ -306,6 +308,7 @@ class GenerationConfig(PushToHubMixin):
        self.min_new_tokens = kwargs.pop("min_new_tokens", None)
        self.early_stopping = kwargs.pop("early_stopping", False)
        self.max_time = kwargs.pop("max_time", None)
+        self.stop_strings = kwargs.pop("stop_strings", None)

        # Parameters that control the generation strategy used
        self.do_sample = kwargs.pop("do_sample", False)
...
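
As a side note (illustrative, not part of the diff), the new `stop_strings` field can equally be set on a `GenerationConfig` and reused across calls; the tokenizer is still supplied at generation time:

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

# stop_strings accepts a single string or a list of strings, per the docstring above
generation_config = GenerationConfig(max_new_tokens=50, stop_strings=["secrets"])

inputs = tokenizer("They completed the challenging puzzle, revealing the hidden", return_tensors="pt")
out = model.generate(**inputs, generation_config=generation_config, tokenizer=tokenizer)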
@@ -80,12 +80,14 @@ from .stopping_criteria import (
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
+    StopStringCriteria,
    validate_stopping_criteria,
)

if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel
+    from ..tokenization_utils_base import PreTrainedTokenizerBase
    from .streamers import BaseStreamer

logger = logging.get_logger(__name__)
@@ -885,7 +887,11 @@ class GenerationMixin:
        return processors

    def _get_stopping_criteria(
-        self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList]
+        self,
+        generation_config: GenerationConfig,
+        stopping_criteria: Optional[StoppingCriteriaList],
+        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
+        **kwargs,
    ) -> StoppingCriteriaList:
        criteria = StoppingCriteriaList()
        if generation_config.max_length is not None:
@@ -898,6 +904,14 @@ class GenerationMixin:
            )
        if generation_config.max_time is not None:
            criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
+        if generation_config.stop_strings is not None:
+            if tokenizer is None:
+                raise ValueError(
+                    "There are one or more stop strings, either in the arguments to `generate` or in the "
+                    "model's generation config, but we could not locate a tokenizer. When generating with "
+                    "stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`."
+                )
+            criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer))
        if generation_config.eos_token_id is not None:
            criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id))
        criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
@@ -1380,6 +1394,7 @@ class GenerationMixin:
        """
        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        self._validate_model_class()
+        tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
        generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
        self._validate_model_kwargs(model_kwargs.copy())
@@ -1389,6 +1404,7 @@ class GenerationMixin:
            synced_gpus = True
        else:
            synced_gpus = False

        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
@@ -1531,7 +1547,7 @@ class GenerationMixin:
        # 9. prepare stopping criteria
        prepared_stopping_criteria = self._get_stopping_criteria(
-            generation_config=generation_config, stopping_criteria=stopping_criteria
+            generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
        )

        # 10. go into different generation modes
        if generation_mode == GenerationMode.ASSISTED_GENERATION:
...
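
The same machinery can also be driven directly (again an illustrative sketch, not part of the diff): a `StopStringCriteria` built by hand and passed via `stopping_criteria=` does not need the extra `tokenizer` argument to `generate()`, because the criterion already received the tokenizer and precomputes the per-token matching data at construction time:

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import StoppingCriteriaList, StopStringCriteria

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

# The criterion caches its string/token matching tables when it is constructed
criteria = StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings=["secrets"])])

inputs = tokenizer("They completed the challenging puzzle, revealing the hidden", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=50, stopping_criteria=criteria)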
@@ -16,7 +16,7 @@
import time
import unittest

-from transformers import is_torch_available
+from transformers import AutoTokenizer, is_torch_available
from transformers.testing_utils import require_torch, torch_device

from ..test_modeling_common import ids_tensor
@@ -31,6 +31,7 @@ if is_torch_available():
        MaxNewTokensCriteria,
        MaxTimeCriteria,
        StoppingCriteriaList,
+        StopStringCriteria,
        validate_stopping_criteria,
    )
@@ -124,3 +125,134 @@ class StoppingCriteriaTestCase(unittest.TestCase):
        stopping_criteria = validate_stopping_criteria(StoppingCriteriaList(), 11)
        self.assertEqual(len(stopping_criteria), 1)
+
+    def test_stop_string_criteria(self):
+        true_strings = [
+            "<|im_start|><|im_end|>",
+            "<|im_start|><|im_end|<|im_end|>",
+            ">><|im_start|>>stop",
+            "stop",
+            "e nd",
+        ]
+        false_strings = [
+            "<|im_start|><|im_end|",
+            "<|im_start|><|im_end|<|im_end|",
+            "<|im_end|><|im_start|>",
+            "<|im_end|<>stop<|im_end|",
+            "end",
+            "en d",
+            "eNd",
+            "<|im_end|",
+            "|im_end|>",
+            "s",
+        ]
+        stop_strings = ["<|im_end|>", "stop", "e nd"]
+
+        # Use a tokenizer that won't actually have special tokens for these
+        tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        tokenizer.padding_side = "left"
+        true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
+        false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
+
+        scores = None
+        criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings)
+        for i in range(len(true_strings)):
+            self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores))
+        for i in range(len(false_strings)):
+            self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores))
+
+        # Now try it with a tokenizer where those are actually special tokens
+        tokenizer = AutoTokenizer.from_pretrained("cognitivecomputations/dolphin-2.5-mixtral-8x7b")
+        tokenizer.padding_side = "left"
+        true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
+        false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
+
+        criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings)
+        for i in range(len(true_strings)):
+            self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores))
+        for i in range(len(false_strings)):
+            self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores))
+
+    def test_stop_string_matching_positions(self):
+        stop_string = "stop"
+        token_list = ["last", "top", "topper", "s", "p"]
+        token_indices = list(range(len(token_list)))
+        all_token_valid_positions, all_token_end_overlaps = StopStringCriteria._stop_string_get_matching_positions(
+            token_list=token_list, token_indices=token_indices, stop_strings=[stop_string]
+        )
+        valid_positions = {
+            token_list[idx]: positions for idx, positions in all_token_valid_positions[stop_string].items()
+        }
+        end_overlaps = {token_list[idx]: overlaps for idx, overlaps in all_token_end_overlaps[stop_string].items()}
+        self.assertEqual(valid_positions, {"s": [3], "last": [2]})
+        self.assertEqual(end_overlaps, {"top": [3], "topper": [3], "p": [1]})
+
+    def test_stop_string_embedding_vecs(self):
+        stop_string = "stop"
+        token_list = ["last", "top", "topper", "s", "p"]
+        token_indices = list(range(len(token_list)))
+        embedding_vec, max_valid_positions, max_valid_end_lens = StopStringCriteria._stop_string_create_embedding_vec(
+            token_list=token_list, token_indices=token_indices, stop_strings=[stop_string]
+        )
+
+        # Positions inside the stop string where the token matches (excluding end overlaps)
+        valid_positions = embedding_vec[:, 0].tolist()
+        self.assertEqual(valid_positions, [2, -1, -1, 3, -1])
+
+        # Overlap lengths between end of stop string and start of token
+        end_overlaps = embedding_vec[:, 1].tolist()
+        self.assertEqual(end_overlaps, [-1, 3, 3, -1, 1])
+
+        # Length of each token
+        token_lengths = embedding_vec[:, 2].tolist()
+        self.assertEqual(token_lengths, [len(token) for token in token_list])
+
+    def test_criterias_per_row(self):
+        text = "They completed the challenging puzzle, revealing the hidden image at the end"
+        stop_strings = ["end"]
+
+        tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False)
+
+        scores = None
+        criteria = StoppingCriteriaList(
+            [
+                MaxLengthCriteria(max_length=20),
+                StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings),
+            ]
+        )
+
+        # trigger stopping when at least one criterion is satisfied, one value per batch
+        self.assertTrue(criteria(inputs["input_ids"], scores))
+
+        # return False when neither is satisfied
+        self.assertFalse(criteria(inputs["input_ids"][:, :-1], scores))
+
+    def test_criterias_per_row_batched(self):
+        text = [
+            "They completed the challenging puzzle, revealing the hidden image at the end",
+            "Today a dragon flew over France",
+            "The aroma of freshly baked pizza filled the kitchen",
+        ]
+        stop_strings = ["end"]
+
+        tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        tokenizer.padding_side = "left"
+        inputs = tokenizer(text, return_tensors="pt", padding="longest", add_special_tokens=False)
+
+        scores = None
+        criteria = StoppingCriteriaList(
+            [
+                MaxLengthCriteria(max_length=20),
+                StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings),
+            ]
+        )
+
+        # trigger stopping when at least one criterion is satisfied
+        self.assertListEqual(criteria(inputs["input_ids"], scores).tolist(), [True, False, False])
+
+        # False when neither is satisfied
+        self.assertListEqual(criteria(inputs["input_ids"][:, :-1], scores).tolist(), [False, False, False])
@@ -2330,6 +2330,43 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
        self.assertListEqual(outputs, ["Wie alt sind Sie?"])

+    @slow
+    def test_per_row_stopping_criteria(self):
+        text = [
+            "They completed the challenging puzzle, revealing the hidden",
+            "Today a dragon flew over France",
+            "The aroma of freshly baked pizza filled the kitchen",
+        ]
+        stop_strings = ["secrets"]
+
+        model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+        tokenizer.padding_side = "left"
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        input_ids = tokenizer(text, return_tensors="pt", padding="longest", add_special_tokens=False).input_ids.to(
+            torch_device
+        )
+
+        # normal generation with one stopping criterion
+        out = model.generate(input_ids, max_length=15)
+        out_text = tokenizer.batch_decode(out)
+        expected_out = [
+            "They completed the challenging puzzle, revealing the hidden secrets of the world.\n",
+            "<|endoftext|><|endoftext|><|endoftext|>Today a dragon flew over France and the French government was forced",
+            "The aroma of freshly baked pizza filled the kitchen with a sense of freshness",
+        ]
+        self.assertListEqual(out_text, expected_out)
+
+        # generation should stop at "secrets" for first batch only, filling the rest with eos tokens
+        out = model.generate(input_ids, max_length=15, stop_strings=stop_strings, tokenizer=tokenizer)
+        out_text = tokenizer.batch_decode(out)
+        expected_out = [
+            "They completed the challenging puzzle, revealing the hidden secrets<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>",
+            "<|endoftext|><|endoftext|><|endoftext|>Today a dragon flew over France and the French government was forced",
+            "The aroma of freshly baked pizza filled the kitchen with a sense of freshness",
+        ]
+        self.assertListEqual(out_text, expected_out)

    def test_constrained_beam_search_mixin_type_checks(self):
        # PT-only test: TF doesn't have constrained beam search
        tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
...