Unverified Commit 7079a99e authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fixing 1-length special tokens cut. (#13862)

parent 7051b892
...@@ -20,6 +20,7 @@ import bisect ...@@ -20,6 +20,7 @@ import bisect
import itertools import itertools
import re import re
import unicodedata import unicodedata
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union, overload from typing import Any, Dict, List, Optional, Tuple, Union, overload
from .file_utils import PaddingStrategy, TensorType, add_end_docstrings from .file_utils import PaddingStrategy, TensorType, add_end_docstrings
...@@ -102,7 +103,6 @@ class Trie: ...@@ -102,7 +103,6 @@ class Trie:
>>> trie.split("[CLS] This is a extra_id_100") >>> trie.split("[CLS] This is a extra_id_100")
["[CLS]", " This is a ", "extra_id_100"] ["[CLS]", " This is a ", "extra_id_100"]
""" """
# indexes are counted left of the chars index. # indexes are counted left of the chars index.
# "hello", index 0, is left of h, index 1 is between h and e. # "hello", index 0, is left of h, index 1 is between h and e.
# index 5 is right of the "o". # index 5 is right of the "o".
...@@ -115,7 +115,7 @@ class Trie: ...@@ -115,7 +115,7 @@ class Trie:
# If the trie contains, "blowing", and "lower" and we encounter the # If the trie contains, "blowing", and "lower" and we encounter the
# string "blower", we need to split into ["b", "lower"]. # string "blower", we need to split into ["b", "lower"].
# This is where we need to keep track of multiple possible starts. # This is where we need to keep track of multiple possible starts.
states = {} states = OrderedDict()
# This will contain every indices where we need # This will contain every indices where we need
# to cut. # to cut.
...@@ -144,36 +144,36 @@ class Trie: ...@@ -144,36 +144,36 @@ class Trie:
# In this case, we already have partial matches (But unfinished) # In this case, we already have partial matches (But unfinished)
for start, trie_pointer in states.items(): for start, trie_pointer in states.items():
if current_char in trie_pointer: if "" in trie_pointer:
# This is a final match, we need to reset and
# store the results in `offsets`.
# Lookahead to match longest first
# Important in case of extra_id_1 vs extra_id_100
lookahead_index = current
end = current
next_char = text[lookahead_index] if lookahead_index < len(text) else None
while next_char in trie_pointer:
trie_pointer = trie_pointer[next_char]
lookahead_index += 1
if "" in trie_pointer:
end = lookahead_index
skip = lookahead_index
if lookahead_index == len(text):
# End of string
break
next_char = text[lookahead_index]
# End lookahead
# Storing and resetting
offsets.append(start)
offsets.append(end)
reset = True
elif current_char in trie_pointer:
# The current character being looked at has a match within the trie # The current character being looked at has a match within the trie
# update the pointer (it will be stored back into states later). # update the pointer (it will be stored back into states later).
trie_pointer = trie_pointer[current_char] trie_pointer = trie_pointer[current_char]
if "" in trie_pointer:
# This is a final match, we need to reset and
# store the results in `offsets`.
# Lookahead to match longest first
# Important in case of extra_id_1 vs extra_id_100
lookahead_index = current + 1
end = current + 1
next_char = text[lookahead_index] if lookahead_index < len(text) else None
while next_char in trie_pointer:
trie_pointer = trie_pointer[next_char]
lookahead_index += 1
if "" in trie_pointer:
end = lookahead_index
skip = lookahead_index
if lookahead_index == len(text):
# End of string
break
next_char = text[lookahead_index]
# End lookahead
# Storing and resetting
offsets.append(start)
offsets.append(end)
reset = True
# Storing back the new pointer into the states. # Storing back the new pointer into the states.
# Partial matches got longer by one. # Partial matches got longer by one.
...@@ -198,6 +198,18 @@ class Trie: ...@@ -198,6 +198,18 @@ class Trie:
if current_char in self.data: if current_char in self.data:
states[current] = self.data[current_char] states[current] = self.data[current_char]
# We have a cut at the end with states.
for start, trie_pointer in states.items():
if "" in trie_pointer:
# This is a final match, we need to reset and
# store the results in `offsets`.
end = len(text)
offsets.append(start)
offsets.append(end)
# Longest cut is always the one with lower start so the first
# item so we need to break.
break
# We have all the offsets now, we just need to do the actual splitting. # We have all the offsets now, we just need to do the actual splitting.
# We need to eventually add the first part of the string and the eventual # We need to eventually add the first part of the string and the eventual
# last part. # last part.
......
...@@ -3562,3 +3562,15 @@ class TrieTest(unittest.TestCase): ...@@ -3562,3 +3562,15 @@ class TrieTest(unittest.TestCase):
trie.add("extra_id_1") trie.add("extra_id_1")
trie.add("extra_id_100") trie.add("extra_id_100")
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"]) self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])
def test_trie_single(self):
trie = Trie()
trie.add("A")
self.assertEqual(trie.split("ABC"), ["A", "BC"])
self.assertEqual(trie.split("BCA"), ["BC", "A"])
def test_trie_final(self):
trie = Trie()
trie.add("TOKEN]")
trie.add("[SPECIAL_TOKEN]")
self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment