Unverified Commit 7079a99e authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fixing 1-length special tokens cut. (#13862)

parent 7051b892
......@@ -20,6 +20,7 @@ import bisect
import itertools
import re
import unicodedata
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union, overload
from .file_utils import PaddingStrategy, TensorType, add_end_docstrings
......@@ -102,7 +103,6 @@ class Trie:
>>> trie.split("[CLS] This is a extra_id_100")
["[CLS]", " This is a ", "extra_id_100"]
"""
# indexes are counted left of the chars index.
# "hello", index 0, is left of h, index 1 is between h and e.
# index 5 is right of the "o".
......@@ -115,7 +115,7 @@ class Trie:
# If the trie contains, "blowing", and "lower" and we encounter the
# string "blower", we need to split into ["b", "lower"].
# This is where we need to keep track of multiple possible starts.
states = {}
states = OrderedDict()
# This will contain every indices where we need
# to cut.
......@@ -144,36 +144,36 @@ class Trie:
# In this case, we already have partial matches (But unfinished)
for start, trie_pointer in states.items():
if current_char in trie_pointer:
if "" in trie_pointer:
# This is a final match, we need to reset and
# store the results in `offsets`.
# Lookahead to match longest first
# Important in case of extra_id_1 vs extra_id_100
lookahead_index = current
end = current
next_char = text[lookahead_index] if lookahead_index < len(text) else None
while next_char in trie_pointer:
trie_pointer = trie_pointer[next_char]
lookahead_index += 1
if "" in trie_pointer:
end = lookahead_index
skip = lookahead_index
if lookahead_index == len(text):
# End of string
break
next_char = text[lookahead_index]
# End lookahead
# Storing and resetting
offsets.append(start)
offsets.append(end)
reset = True
elif current_char in trie_pointer:
# The current character being looked at has a match within the trie
# update the pointer (it will be stored back into states later).
trie_pointer = trie_pointer[current_char]
if "" in trie_pointer:
# This is a final match, we need to reset and
# store the results in `offsets`.
# Lookahead to match longest first
# Important in case of extra_id_1 vs extra_id_100
lookahead_index = current + 1
end = current + 1
next_char = text[lookahead_index] if lookahead_index < len(text) else None
while next_char in trie_pointer:
trie_pointer = trie_pointer[next_char]
lookahead_index += 1
if "" in trie_pointer:
end = lookahead_index
skip = lookahead_index
if lookahead_index == len(text):
# End of string
break
next_char = text[lookahead_index]
# End lookahead
# Storing and resetting
offsets.append(start)
offsets.append(end)
reset = True
# Storing back the new pointer into the states.
# Partial matches got longer by one.
......@@ -198,6 +198,18 @@ class Trie:
if current_char in self.data:
states[current] = self.data[current_char]
# We have a cut at the end with states.
for start, trie_pointer in states.items():
if "" in trie_pointer:
# This is a final match, we need to reset and
# store the results in `offsets`.
end = len(text)
offsets.append(start)
offsets.append(end)
# Longest cut is always the one with lower start so the first
# item so we need to break.
break
# We have all the offsets now, we just need to do the actual splitting.
# We need to eventually add the first part of the string and the eventual
# last part.
......
......@@ -3562,3 +3562,15 @@ class TrieTest(unittest.TestCase):
trie.add("extra_id_1")
trie.add("extra_id_100")
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])
def test_trie_single(self):
trie = Trie()
trie.add("A")
self.assertEqual(trie.split("ABC"), ["A", "BC"])
self.assertEqual(trie.split("BCA"), ["BC", "A"])
def test_trie_final(self):
trie = Trie()
trie.add("TOKEN]")
trie.add("[SPECIAL_TOKEN]")
self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment