"...resnet50_tensorflow.git" did not exist on "8b43ab7c122a6eeb22994e329d9b6cb4bb57ca91"
Unverified Commit d70919e6 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Adding support for tokens being suffixes or part of each other. (#13918)

* Adding support for tokens being suffixes or part of each other.

* Better test name.
parent 026866df
...@@ -150,13 +150,30 @@ class Trie: ...@@ -150,13 +150,30 @@ class Trie:
# Lookahead to match longest first # Lookahead to match longest first
# Important in case of extra_id_1 vs extra_id_100 # Important in case of extra_id_1 vs extra_id_100
# Here we are also actively looking for other earlier partial
# matches
# "[CLS]", "L", we need to match CLS even if L is special
for lookstart, looktrie_pointer in states.items():
if lookstart > start:
# This partial match is later, we can stop looking
break
elif lookstart < start:
# This partial match is earlier, the trie pointer
# was already updated, so index is + 1
lookahead_index = current + 1
end = current + 1
else:
# Here lookstart == start and
# looktrie_pointer == trie_pointer
# It wasn't updated yet so indices are current ones
lookahead_index = current lookahead_index = current
end = current end = current
next_char = text[lookahead_index] if lookahead_index < len(text) else None next_char = text[lookahead_index] if lookahead_index < len(text) else None
while next_char in trie_pointer: while next_char in looktrie_pointer:
trie_pointer = trie_pointer[next_char] looktrie_pointer = looktrie_pointer[next_char]
lookahead_index += 1 lookahead_index += 1
if "" in trie_pointer: if "" in looktrie_pointer:
start = lookstart
end = lookahead_index end = lookahead_index
skip = lookahead_index skip = lookahead_index
...@@ -170,6 +187,7 @@ class Trie: ...@@ -170,6 +187,7 @@ class Trie:
offsets.append(start) offsets.append(start)
offsets.append(end) offsets.append(end)
reset = True reset = True
break
elif current_char in trie_pointer: elif current_char in trie_pointer:
# The current character being looked at has a match within the trie # The current character being looked at has a match within the trie
# update the pointer (it will be stored back into states later). # update the pointer (it will be stored back into states later).
...@@ -210,6 +228,9 @@ class Trie: ...@@ -210,6 +228,9 @@ class Trie:
# item so we need to break. # item so we need to break.
break break
return self.cut_text(text, offsets)
def cut_text(self, text, offsets):
# We have all the offsets now, we just need to do the actual splitting. # We have all the offsets now, we just need to do the actual splitting.
# We need to eventually add the first part of the string and the eventual # We need to eventually add the first part of the string and the eventual
# last part. # last part.
...@@ -217,7 +238,12 @@ class Trie: ...@@ -217,7 +238,12 @@ class Trie:
tokens = [] tokens = []
start = 0 start = 0
for end in offsets: for end in offsets:
if start == end: if start > end:
logger.error(
"There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it anyway."
)
continue
elif start == end:
# This might happen if there's a match at index 0 # This might happen if there's a match at index 0
# we're also preventing zero-width cuts in case of two # we're also preventing zero-width cuts in case of two
# consecutive matches # consecutive matches
......
...@@ -3574,3 +3574,24 @@ class TrieTest(unittest.TestCase): ...@@ -3574,3 +3574,24 @@ class TrieTest(unittest.TestCase):
trie.add("TOKEN]") trie.add("TOKEN]")
trie.add("[SPECIAL_TOKEN]") trie.add("[SPECIAL_TOKEN]")
self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"]) self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
def test_trie_subtokens(self):
    """Single-char tokens that appear inside a longer special token must not
    prevent the longer token from being matched as a whole."""
    trie = Trie()
    for token in ("A", "P", "[SPECIAL_TOKEN]"):
        trie.add(token)
    expected = ["This is something ", "[SPECIAL_TOKEN]"]
    self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), expected)
def test_trie_suffix_tokens(self):
    """A token that is a suffix of another token ("B" inside "AB") must not
    steal the match: the longest-first split wins."""
    trie = Trie()
    for token in ("AB", "B", "C"):
        trie.add(token)
    self.assertEqual(trie.split("ABC"), ["AB", "C"])
def test_cut_text_hardening(self):
    """cut_text must recover from inconsistent offsets (duplicates, a
    backwards jump) and still return only valid, non-empty string parts."""
    trie = Trie()
    bogus_offsets = [0, 0, 2, 1, 2, 3]
    parts = trie.cut_text("ABC", bogus_offsets)
    self.assertEqual(parts, ["AB", "C"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment