Unverified Commit 3dd538c4 authored by Nicolas Patry, committed by GitHub

[Tentative] Moving slow tokenizer to the Trie world. (#13220)



* Moving slow tokenizer to the Trie world.

* Adding more docstrings to the Trie.

* Fixing doctest (incompatible with our format?)

* Update src/transformers/tokenization_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Adding a lot more comments on the internals of this algorithm.

* Cleaner doc.

* Fixing the namings.

* Update src/transformers/tokenization_utils.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* quality.

* Fixing longest first match.

* Small improvements to cuts + more tests + Canine-resistant test.

* Fixing fast test.
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>
parent b8385d8a
@@ -148,6 +148,8 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
if len(token) > 1:
self.unique_no_split_tokens.append(token)
self._create_trie(self.unique_no_split_tokens)
@property
def word_delimiter_token(self) -> str:
"""
@@ -330,6 +332,8 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
self._additional_special_tokens.append(AddedToken(token))
_insert_one_token_to_ordered_list(self.unique_no_split_tokens, token)
self._create_trie(self.unique_no_split_tokens)
return len(tokens_to_add)
...
@@ -49,6 +49,173 @@ ADDED_TOKENS_FILE = "added_tokens.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
class Trie:
"""
Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one pass.
Loose reference: https://en.wikipedia.org/wiki/Trie
"""
def __init__(self):
self.data = {}
def add(self, word: str):
"""
Passes over every char (utf-8 char) of `word` and recursively adds it to the internal `data` trie representation.
The special key `""` is used to represent termination.
This function is idempotent: adding the same word twice leaves the trie unchanged.
Example::
>>> trie = Trie()
>>> trie.add("Hello 友達")
>>> trie.data
{"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}
>>> trie.add("Hello")
>>> trie.data
{"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}
"""
if not word:
# Prevent empty string
return
ref = self.data
for char in word:
ref[char] = char in ref and ref[char] or {}
ref = ref[char]
ref[""] = 1
def split(self, text: str) -> List[str]:
"""
Will look for the words added to the trie within `text`. Output is the original string split along the
boundaries of the words found.
This trie will match the longest possible word first!
Example::
>>> trie = Trie()
>>> trie.split("[CLS] This is a extra_id_100")
["[CLS] This is a extra_id_100"]
>>> trie.add("[CLS]")
>>> trie.add("extra_id_1")
>>> trie.add("extra_id_100")
>>> trie.split("[CLS] This is a extra_id_100")
["[CLS]", " This is a ", "extra_id_100"]
"""
# Indexes are counted left of the chars: in "hello", index 0 is left of "h",
# index 1 is between "h" and "e", and index 5 is right of the "o".
# `states` captures every possible start (indexes as above) as keys, and has
# as values a pointer to the position in the trie where we're at. These are
# partial matches for now.
# This lets us keep track of multiple matches while we're iterating over
# the string.
# If the trie contains "blowing" and "lower" and we encounter the
# string "blower", we need to split into ["b", "lower"].
# This is where we need to keep track of multiple possible starts.
states = {}
# This will contain every index where we need to cut.
# We force a cut at offset 0 and at len(text) (added later).
offsets = [0]
# This is used by the lookahead, which needs to skip over
# some text where the full match went beyond the position of the
# main for loop.
skip = None
# Main loop, giving this algorithm O(n) complexity
for current, current_char in enumerate(text):
if skip and current < skip:
# Prevents the lookahead from matching twice,
# e.g. on extra_id_100 and id_100
continue
# This will track every state
# that stops matching; we need to stop tracking them.
# If we look at "lowball", we're going to match "l" (add it to states), "o", "w", then
# fail on "b"; we need to remove 0 from the valid states.
to_remove = set()
# Whenever we find a match, we need to drop everything:
# this is a greedy algorithm, it will match on the first found token.
reset = False
# In this case, we already have partial matches (but unfinished)
for start, trie_pointer in states.items():
if current_char in trie_pointer:
# The current character being looked at has a match within the trie
# update the pointer (it will be stored back into states later).
trie_pointer = trie_pointer[current_char]
if "" in trie_pointer:
# This is a final match, we need to reset and
# store the results in `offsets`.
# Lookahead to match longest first
# Important in case of extra_id_1 vs extra_id_100
lookahead_index = current + 1
end = current + 1
next_char = text[lookahead_index] if lookahead_index < len(text) else None
while next_char in trie_pointer:
trie_pointer = trie_pointer[next_char]
lookahead_index += 1
if "" in trie_pointer:
end = lookahead_index
skip = lookahead_index
if lookahead_index == len(text):
# End of string
break
next_char = text[lookahead_index]
# End lookahead
# Storing and resetting
offsets.append(start)
offsets.append(end)
reset = True
# Storing back the new pointer into the states.
# Partial matches got longer by one.
states[start] = trie_pointer
else:
# The new character has no match in the trie, we need
# to stop keeping track of this partial match.
# We can't do it directly within the loop because of how
# Python iteration works.
to_remove.add(start)
# Either clear all the starts (we found a real match)
# or clear only the partial matches that didn't work out.
if reset:
states = {}
else:
for start in to_remove:
del states[start]
# If this character is a starting character within the trie
# start keeping track of this partial match.
if current_char in self.data:
states[current] = self.data[current_char]
# We have all the offsets now, we just need to do the actual splitting.
# We need to include the first part of the string and the last part as well.
offsets.append(len(text))
tokens = []
start = 0
for end in offsets:
if start == end:
# This might happen if there's a match at index 0
# we're also preventing zero-width cuts in case of two
# consecutive matches
continue
tokens.append(text[start:end])
start = end
return tokens
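To make the behaviors described in the comments above concrete (add reusing an existing child dict instead of overwriting it, split matching longest-first while tracking several possible starts), here is a minimal usage sketch. It assumes nothing beyond the Trie class defined above; the words and sentences are made up:

trie = Trie()
trie.add("ab")
trie.add("ac")  # "a" already has a child dict, so it is reused, not overwritten
trie.add("ab")  # idempotent: adding the same word again changes nothing
# trie.data is now {"a": {"b": {"": 1}, "c": {"": 1}}}

trie = Trie()
trie.add("extra_id_1")
trie.add("extra_id_100")
# The lookahead makes the longer token win over its prefix extra_id_1.
print(trie.split("This is a extra_id_100"))  # ['This is a ', 'extra_id_100']

trie = Trie()
trie.add("blowing")
trie.add("lower")
# "blowing" starts matching at index 0 but fails on "e"; the match starting
# at index 1 ("lower") succeeds, which is why several starts are tracked.
print(trie.split("blower"))  # ['b', 'lower']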
def _is_whitespace(char):
"""Checks whether `char` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
@@ -135,6 +302,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
self.added_tokens_encoder: Dict[str, int] = {}
self.added_tokens_decoder: Dict[int, str] = {}
self.unique_no_split_tokens: List[str] = []
self.tokens_trie = Trie()
self._decode_use_source_tokenizer = False
@@ -223,9 +391,19 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
_insert_one_token_to_ordered_list(self.unique_no_split_tokens, tokens_to_add[0])
else:
self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(tokens_to_add)))
self._create_trie(self.unique_no_split_tokens)
return len(tokens_to_add)
def _create_trie(self, unique_no_split_tokens):
trie = Trie()
for token in unique_no_split_tokens:
if hasattr(self, "do_lower_case") and self.do_lower_case and token not in self.all_special_tokens:
trie.add(token.lower())
else:
trie.add(token)
self.tokens_trie = trie
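For do_lower_case tokenizers, _create_trie lowercases regular added tokens but keeps special tokens verbatim, mirroring the lowercasing applied to the input text in tokenize below. A standalone sketch of that rule, with invented token lists and no real tokenizer instantiated:

# Purely illustrative: mimic _create_trie's lowercasing rule outside the class.
do_lower_case = True
all_special_tokens = ["[CLS]", "[SEP]"]
unique_no_split_tokens = ["[CLS]", "[SEP]", "NewToken"]

trie = Trie()
for token in unique_no_split_tokens:
    if do_lower_case and token not in all_special_tokens:
        trie.add(token.lower())  # "newtoken", so it matches the lowercased text
    else:
        trie.add(token)  # special tokens keep their casing
print(trie.split("[CLS] this has a newtoken [SEP]"))
# ['[CLS]', ' this has a ', 'newtoken', ' ', '[SEP]']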
def num_special_tokens_to_add(self, pair: bool = False) -> int:
"""
Returns the number of added tokens when encoding a sequence with special tokens.
@@ -279,87 +457,39 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
Removed (the previous split_on_token / split_on_tokens implementation):

def split_on_token(tok, text):
    result = []
    tok_extended = all_special_tokens_extended.get(tok, None)
    split_text = text.split(tok)
    full_word = ""
    for i, sub_text in enumerate(split_text):
        # AddedToken can control whitespace stripping around them.
        # We use them for GPT2 and Roberta to have different behavior depending on the special token
        # Cf. https://github.com/huggingface/transformers/pull/2778
        # and https://github.com/huggingface/transformers/issues/3788
        if isinstance(tok_extended, AddedToken):
            if tok_extended.single_word:
                # Try to avoid splitting on token
                if (
                    i < len(split_text) - 1
                    and not _is_end_of_word(sub_text)
                    and not _is_start_of_word(split_text[i + 1])
                ):
                    # Don't extract the special token
                    full_word += sub_text + tok
                elif full_word:
                    full_word += sub_text
                    result.append(full_word)
                    full_word = ""
                    continue
            # Strip white spaces on the right
            if tok_extended.rstrip and i > 0:
                # A bit counter-intuitive but we strip the left of the string
                # since tok_extended.rstrip means the special token is eating all white spaces on its right
                sub_text = sub_text.lstrip()
            # Strip white spaces on the left
            if tok_extended.lstrip and i < len(split_text) - 1:
                sub_text = sub_text.rstrip()  # Opposite here
        else:
            # We strip left and right by default
            if i < len(split_text) - 1:
                sub_text = sub_text.rstrip()
            if i > 0:
                sub_text = sub_text.lstrip()

        if i == 0 and not sub_text:
            result.append(tok)
        elif i == len(split_text) - 1:
            if sub_text:
                result.append(sub_text)
        else:
            if sub_text:
                result.append(sub_text)
            result.append(tok)
    return result

def split_on_tokens(tok_list, text):
    if not text.strip():
        return []
    if not tok_list:
        return self._tokenize(text)

    tokenized_text = []
    text_list = [text]
    for tok in tok_list:
        tokenized_text = []
        for sub_text in text_list:
            if sub_text not in self.unique_no_split_tokens:
                tokenized_text.extend(split_on_token(tok, sub_text))
            else:
                tokenized_text.append(sub_text)
        text_list = tokenized_text

    return list(
        itertools.chain.from_iterable(
            (
                self._tokenize(token) if token not in self.unique_no_split_tokens else [token]
                for token in tokenized_text
            )
        )
    )

no_split_token = self.unique_no_split_tokens
tokenized_text = split_on_tokens(no_split_token, text)
return tokenized_text

Added (the new Trie-based implementation):

no_split_token = set(self.unique_no_split_tokens)
tokens = self.tokens_trie.split(text)
# ["This is something", "<special_token_1>", " else"]
for i, token in enumerate(tokens):
    if token in no_split_token:
        tok_extended = all_special_tokens_extended.get(token, None)
        left = tokens[i - 1] if i > 0 else None
        right = tokens[i + 1] if i < len(tokens) - 1 else None
        if isinstance(tok_extended, AddedToken):
            if tok_extended.rstrip and right:
                # A bit counter-intuitive but we strip the left of the string
                # since tok_extended.rstrip means the special token is eating all white spaces on its right
                tokens[i + 1] = right.lstrip()
            # Strip white spaces on the left
            if tok_extended.lstrip and left:
                tokens[i - 1] = left.rstrip()  # Opposite here
        else:
            # We strip left and right by default
            if right:
                tokens[i + 1] = right.lstrip()
            if left:
                tokens[i - 1] = left.rstrip()
# ["This is something", "<special_token_1>", "else"]
tokenized_text = []
for token in tokens:
    # Need to skip eventual empty (fully stripped) tokens
    if not token:
        continue
    if token in no_split_token:
        tokenized_text.append(token)
    else:
        tokenized_text.extend(self._tokenize(token))
# ["This", " is", " something", "<special_token_1>", "else"]
return tokenized_text
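For intuition about the stripping step above, here is a standalone sketch of the default (non-AddedToken) behavior around a matched special token. The <special> token and the token list are invented; in the real method the list comes from self.tokens_trie.split(text):

# Illustrative only: mimic the default stripping around a matched special token.
no_split_token = {"<special>"}
tokens = ["Hello ", "<special>", " world"]  # what tokens_trie.split could return
for i, token in enumerate(tokens):
    if token in no_split_token:
        left = tokens[i - 1] if i > 0 else None
        right = tokens[i + 1] if i < len(tokens) - 1 else None
        # By default, whitespace is eaten on both sides of the special token.
        if right:
            tokens[i + 1] = right.lstrip()
        if left:
            tokens[i - 1] = left.rstrip()
print(tokens)  # ['Hello', '<special>', 'world']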
def _tokenize(self, text, **kwargs):
...
@@ -55,7 +55,7 @@ from transformers.testing_utils import (
require_torch,
slow,
)
from transformers.tokenization_utils import AddedToken, Trie
if is_torch_available():
@@ -1659,6 +1659,34 @@ class TokenizerTesterMixin:
encoded_sequences_batch_padded_2[key],
)
@require_tokenizers
def test_added_token_are_matched_longest_first(self):
if not self.test_slow_tokenizer:
self.skipTest("This test is only for slow tokenizers")
return
tokenizers = self.get_tokenizers(fast=False)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
try:
tokenizer.add_tokens([AddedToken("extra_id_1")])
tokenizer.add_tokens([AddedToken("extra_id_100")])
except Exception:
# Canine cannot add tokens which are not codepoints
self.skipTest("Cannot add those Added tokens")
# XXX: This used to split on `extra_id_1` first; we're matching
# longest first now.
tokens = tokenizer.tokenize("This is some extra_id_100")
self.assertIn("extra_id_100", tokens)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
tokenizer.add_tokens([AddedToken("extra_id_100")])
tokenizer.add_tokens([AddedToken("extra_id_1")])
tokens = tokenizer.tokenize("This is some extra_id_100")
self.assertIn("extra_id_100", tokens)
@require_tokenizers
def test_added_token_serializable(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
@@ -3489,3 +3517,21 @@ class TokenizerPushToHubTester(unittest.TestCase):
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
class TrieTest(unittest.TestCase):
def test_trie(self):
trie = Trie()
trie.add("Hello 友達")
self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}})
trie.add("Hello")
self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}})
def test_trie_split(self):
trie = Trie()
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS] This is a extra_id_100"])
trie.add("[CLS]")
trie.add("extra_id_1")
trie.add("extra_id_100")
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])