Unverified commit d70919e6 authored by Nicolas Patry, committed by GitHub

Adding support for tokens being suffixes or part of each other. (#13918)

* Adding support for tokens being suffixes or part of each other.

* Better test name.
parent 026866df
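To make the change concrete, here is a minimal usage sketch (an illustration only, assuming the Trie class is importable from transformers.tokenization_utils at this revision):

from transformers.tokenization_utils import Trie

trie = Trie()
trie.add("[CLS]")
trie.add("L")

# "L" completes first, but the earlier partial match on "[CLS]" is now
# resumed, so the longest token wins:
print(trie.split("[CLS]"))  # ["[CLS]"] — previously this could split around the inner "L"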
@@ -150,26 +150,44 @@ class Trie:
                     # Lookahead to match longest first
                     # Important in case of extra_id_1 vs extra_id_100
-                    lookahead_index = current
-                    end = current
-                    next_char = text[lookahead_index] if lookahead_index < len(text) else None
-                    while next_char in trie_pointer:
-                        trie_pointer = trie_pointer[next_char]
-                        lookahead_index += 1
-                        if "" in trie_pointer:
-                            end = lookahead_index
-                            skip = lookahead_index
-
-                        if lookahead_index == len(text):
-                            # End of string
-                            break
-                        next_char = text[lookahead_index]
-                    # End lookahead
+                    # Here we are also actively looking for other earlier partial
+                    # matches
+                    # "[CLS]", "L", we need to match CLS even if L is special
+                    for lookstart, looktrie_pointer in states.items():
+                        if lookstart > start:
+                            # This partial match is later, we can stop looking
+                            break
+                        elif lookstart < start:
+                            # This partial match is earlier, the trie pointer
+                            # was already updated, so index is + 1
+                            lookahead_index = current + 1
+                            end = current + 1
+                        else:
+                            # Here lookstart == start and
+                            #      looktrie_pointer == trie_pointer
+                            # It wasn't updated yet so indices are current ones
+                            lookahead_index = current
+                            end = current
+                        next_char = text[lookahead_index] if lookahead_index < len(text) else None
+                        while next_char in looktrie_pointer:
+                            looktrie_pointer = looktrie_pointer[next_char]
+                            lookahead_index += 1
+                            if "" in looktrie_pointer:
+                                start = lookstart
+                                end = lookahead_index
+                                skip = lookahead_index
+
+                            if lookahead_index == len(text):
+                                # End of string
+                                break
+                            next_char = text[lookahead_index]
+                        # End lookahead
 
                     # Storing and resetting
                     offsets.append(start)
                     offsets.append(end)
                     reset = True
+                    break
                 elif current_char in trie_pointer:
                     # The current character being looked at has a match within the trie
                     # update the pointer (it will be stored back into states later).
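To see why the lookahead must revisit every active partial match, consider this hypothetical trace (illustration only; the states dict maps match start positions to nested-dict trie nodes, with "" marking a terminal):

from transformers.tokenization_utils import Trie  # same assumption as above

trie = Trie()
trie.add("AB")
trie.add("B")
# current=0, 'A': one partial match begins  -> states = {0: {"B": {"": 1}}}
# current=1, 'B': it advances, and a second
#                 match starts on "B"       -> states = {0: {"": 1}, 1: {"": 1}}
# current=2, 'C': the match started at 0 completes. The new loop scans
#     states.items() from the earliest start and stops once lookstart > start,
#     so the longer "AB" beats its own suffix "B" (cut: start=0, end=2).
assert trie.split("ABC") == ["AB", "C"]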
@@ -210,6 +228,9 @@ class Trie:
                 # item so we need to break.
                 break
 
+        return self.cut_text(text, offsets)
+
+    def cut_text(self, text, offsets):
         # We have all the offsets now, we just need to do the actual splitting.
         # We need to eventually add the first part of the string and the eventual
         # last part.
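Splitting is now two separable steps: split() collects cut offsets and cut_text() turns them into string parts. A small sketch of that contract, with hand-computed offsets (the special token spans text[18:33]; per the comments above, cut_text appends the final len(text) boundary itself):

from transformers.tokenization_utils import Trie  # same assumption as above

trie = Trie()
trie.add("[SPECIAL_TOKEN]")
text = "This is something [SPECIAL_TOKEN]"
# Only the interior cut points are needed; slices are taken between them:
assert trie.cut_text(text, [18, 33]) == ["This is something ", "[SPECIAL_TOKEN]"]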
@@ -217,7 +238,12 @@ class Trie:
         tokens = []
         start = 0
         for end in offsets:
-            if start == end:
+            if start > end:
+                logger.error(
+                    "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it anyway."
+                )
+                continue
+            elif start == end:
                 # This might happen if there's a match at index 0
                 # we're also preventing zero-width cuts in case of two
                 # consecutive matches
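The new guard makes cut_text defensive: an offset that goes backwards is logged and skipped rather than producing a nonsense slice. A hypothetical illustration (split() itself should never emit such offsets):

from transformers.tokenization_utils import Trie  # same assumption as above

trie = Trie()
# The offset 3 arrives after 5, triggers logger.error, and is skipped;
# slicing resumes from position 5:
assert trie.cut_text("hello world", [5, 3, 11]) == ["hello", " world"]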
@@ -3574,3 +3574,24 @@ class TrieTest(unittest.TestCase):
         trie.add("TOKEN]")
         trie.add("[SPECIAL_TOKEN]")
         self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
+
+    def test_trie_subtokens(self):
+        trie = Trie()
+        trie.add("A")
+        trie.add("P")
+        trie.add("[SPECIAL_TOKEN]")
+        self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
+
+    def test_trie_suffix_tokens(self):
+        trie = Trie()
+        trie.add("AB")
+        trie.add("B")
+        trie.add("C")
+        self.assertEqual(trie.split("ABC"), ["AB", "C"])
+
+    def test_cut_text_hardening(self):
+        # Even if the offsets are wrong, we necessarily output correct string
+        # parts.
+        trie = Trie()
+        parts = trie.cut_text("ABC", [0, 0, 2, 1, 2, 3])
+        self.assertEqual(parts, ["AB", "C"])