Unverified Commit d70919e6 authored by Nicolas Patry, committed by GitHub

Adding support for tokens being suffixes or part of each other. (#13918)

* Adding support for tokens being suffixes or part of each other.

* Better test name.
parent 026866df
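
For illustration, not part of the commit: `Trie.split` previously assumed added tokens were disjoint, and could mis-cut text when one token was a suffix of another (e.g. "B" and "AB") or a fragment of another (e.g. "L" inside "[CLS]"). A minimal sketch of the behavior this commit guarantees, assuming `Trie` is importable from `transformers.tokenization_utils`:

    from transformers.tokenization_utils import Trie

    trie = Trie()
    trie.add("AB")
    trie.add("B")
    trie.add("C")
    # "B" is a suffix of "AB": the longest match starting earliest must win,
    # so "AB" is cut out as one token rather than stranding "A".
    print(trie.split("ABC"))  # ["AB", "C"]
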
@@ -150,13 +150,30 @@ class Trie:
                     # Lookahead to match longest first
                     # Important in case of extra_id_1 vs extra_id_100
-                    lookahead_index = current
-                    end = current
-                    next_char = text[lookahead_index] if lookahead_index < len(text) else None
-                    while next_char in trie_pointer:
-                        trie_pointer = trie_pointer[next_char]
-                        lookahead_index += 1
-                        if "" in trie_pointer:
-                            end = lookahead_index
-                            skip = lookahead_index
-
-                        next_char = text[lookahead_index] if lookahead_index < len(text) else None
+                    # Here we are also actively looking for other earlier partial
+                    # matches
+                    # "[CLS]", "L", we need to match CLS even if L is special
+                    for lookstart, looktrie_pointer in states.items():
+                        if lookstart > start:
+                            # This partial match is later, we can stop looking
+                            break
+                        elif lookstart < start:
+                            # This partial match is earlier, the trie pointer
+                            # was already updated, so index is + 1
+                            lookahead_index = current + 1
+                            end = current + 1
+                        else:
+                            # Here lookstart == start and
+                            #      looktrie_pointer == trie_pointer
+                            # It wasn't updated yet so indices are current ones
+                            lookahead_index = current
+                            end = current
+                        next_char = text[lookahead_index] if lookahead_index < len(text) else None
+                        while next_char in looktrie_pointer:
+                            looktrie_pointer = looktrie_pointer[next_char]
+                            lookahead_index += 1
+                            if "" in looktrie_pointer:
+                                start = lookstart
+                                end = lookahead_index
+                                skip = lookahead_index
+
+                            next_char = text[lookahead_index] if lookahead_index < len(text) else None
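
For illustration, not part of the diff: `states` maps the start index of each in-flight partial match to its current trie node, and the new loop revisits every such match, so an earlier match can absorb a later special token that is merely a fragment of it. A minimal sketch of the "[CLS]" vs. "L" case from the comment above, again assuming `Trie` comes from `transformers.tokenization_utils`:

    from transformers.tokenization_utils import Trie

    trie = Trie()
    trie.add("[CLS]")
    trie.add("L")  # "L" also occurs inside "[CLS]"
    # The lookahead lets the partial match started at "[" win over the
    # lone "L", so the full special token is cut out as a single piece.
    print(trie.split("[CLS]"))  # ["[CLS]"], not ["[C", "L", "S]"]
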
@@ -170,6 +187,7 @@ class Trie:
                     offsets.append(start)
                     offsets.append(end)
                     reset = True
+                    break
                 elif current_char in trie_pointer:
                     # The current character being looked at has a match within the trie
                     # update the pointer (it will be stored back into states later).
@@ -210,6 +228,9 @@ class Trie:
                 # item so we need to break.
                 break
 
+        return self.cut_text(text, offsets)
+
+    def cut_text(self, text, offsets):
         # We have all the offsets now, we just need to do the actual splitting.
         # We need to eventually add the first part of the string and the eventual
         # last part.
@@ -217,7 +238,12 @@ class Trie:
         tokens = []
         start = 0
         for end in offsets:
-            if start == end:
+            if start > end:
+                logger.error(
+                    "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it anyway."
+                )
+                continue
+            elif start == end:
                 # This might happen if there's a match at index 0
                 # we're also preventing zero-width cuts in case of two
                 # consecutive matches
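
For illustration, not part of the diff: `cut_text` now tolerates malformed offsets, skipping zero-width cuts and logging an error for a backwards pair instead of producing wrong slices. A quick sketch mirroring the hardening test added below:

    from transformers.tokenization_utils import Trie

    trie = Trie()
    # (0, 0) is a zero-width cut and (2, 1) is backwards; the first is
    # silently skipped, the second logs an error and is skipped, so the
    # emitted parts still tile the input correctly.
    print(trie.cut_text("ABC", [0, 0, 2, 1, 2, 3]))  # ["AB", "C"]
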
@@ -3574,3 +3574,24 @@ class TrieTest(unittest.TestCase):
         trie.add("TOKEN]")
         trie.add("[SPECIAL_TOKEN]")
         self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
+
+    def test_trie_subtokens(self):
+        trie = Trie()
+        trie.add("A")
+        trie.add("P")
+        trie.add("[SPECIAL_TOKEN]")
+        self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
+
+    def test_trie_suffix_tokens(self):
+        trie = Trie()
+        trie.add("AB")
+        trie.add("B")
+        trie.add("C")
+        self.assertEqual(trie.split("ABC"), ["AB", "C"])
+
+    def test_cut_text_hardening(self):
+        # Even if the offsets are wrong, we necessarily output correct string
+        # parts.
+        trie = Trie()
+        parts = trie.cut_text("ABC", [0, 0, 2, 1, 2, 3])
+        self.assertEqual(parts, ["AB", "C"])