Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7079a99e
Unverified
Commit
7079a99e
authored
Oct 05, 2021
by
Nicolas Patry
Committed by
GitHub
Oct 05, 2021
Browse files
Fixing 1-length special tokens cut. (#13862)
parent
7051b892
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
53 additions
and
29 deletions
+53
-29
src/transformers/tokenization_utils.py
src/transformers/tokenization_utils.py
+41
-29
tests/test_tokenization_common.py
tests/test_tokenization_common.py
+12
-0
No files found.
src/transformers/tokenization_utils.py
View file @
7079a99e
...
@@ -20,6 +20,7 @@ import bisect
...
@@ -20,6 +20,7 @@ import bisect
import
itertools
import
itertools
import
re
import
re
import
unicodedata
import
unicodedata
from
collections
import
OrderedDict
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
,
overload
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
,
overload
from
.file_utils
import
PaddingStrategy
,
TensorType
,
add_end_docstrings
from
.file_utils
import
PaddingStrategy
,
TensorType
,
add_end_docstrings
...
@@ -102,7 +103,6 @@ class Trie:
...
@@ -102,7 +103,6 @@ class Trie:
>>> trie.split("[CLS] This is a extra_id_100")
>>> trie.split("[CLS] This is a extra_id_100")
["[CLS]", " This is a ", "extra_id_100"]
["[CLS]", " This is a ", "extra_id_100"]
"""
"""
# indexes are counted left of the chars index.
# indexes are counted left of the chars index.
# "hello", index 0, is left of h, index 1 is between h and e.
# "hello", index 0, is left of h, index 1 is between h and e.
# index 5 is right of the "o".
# index 5 is right of the "o".
...
@@ -115,7 +115,7 @@ class Trie:
...
@@ -115,7 +115,7 @@ class Trie:
# If the trie contains, "blowing", and "lower" and we encounter the
# If the trie contains, "blowing", and "lower" and we encounter the
# string "blower", we need to split into ["b", "lower"].
# string "blower", we need to split into ["b", "lower"].
# This is where we need to keep track of multiple possible starts.
# This is where we need to keep track of multiple possible starts.
states
=
{}
states
=
OrderedDict
()
# This will contain every indices where we need
# This will contain every indices where we need
# to cut.
# to cut.
...
@@ -144,18 +144,14 @@ class Trie:
...
@@ -144,18 +144,14 @@ class Trie:
# In this case, we already have partial matches (But unfinished)
# In this case, we already have partial matches (But unfinished)
for
start
,
trie_pointer
in
states
.
items
():
for
start
,
trie_pointer
in
states
.
items
():
if
current_char
in
trie_pointer
:
# The current character being looked at has a match within the trie
# update the pointer (it will be stored back into states later).
trie_pointer
=
trie_pointer
[
current_char
]
if
""
in
trie_pointer
:
if
""
in
trie_pointer
:
# This is a final match, we need to reset and
# This is a final match, we need to reset and
# store the results in `offsets`.
# store the results in `offsets`.
# Lookahead to match longest first
# Lookahead to match longest first
# Important in case of extra_id_1 vs extra_id_100
# Important in case of extra_id_1 vs extra_id_100
lookahead_index
=
current
+
1
lookahead_index
=
current
end
=
current
+
1
end
=
current
next_char
=
text
[
lookahead_index
]
if
lookahead_index
<
len
(
text
)
else
None
next_char
=
text
[
lookahead_index
]
if
lookahead_index
<
len
(
text
)
else
None
while
next_char
in
trie_pointer
:
while
next_char
in
trie_pointer
:
trie_pointer
=
trie_pointer
[
next_char
]
trie_pointer
=
trie_pointer
[
next_char
]
...
@@ -174,6 +170,10 @@ class Trie:
...
@@ -174,6 +170,10 @@ class Trie:
offsets
.
append
(
start
)
offsets
.
append
(
start
)
offsets
.
append
(
end
)
offsets
.
append
(
end
)
reset
=
True
reset
=
True
elif
current_char
in
trie_pointer
:
# The current character being looked at has a match within the trie
# update the pointer (it will be stored back into states later).
trie_pointer
=
trie_pointer
[
current_char
]
# Storing back the new pointer into the states.
# Storing back the new pointer into the states.
# Partial matches got longer by one.
# Partial matches got longer by one.
...
@@ -198,6 +198,18 @@ class Trie:
...
@@ -198,6 +198,18 @@ class Trie:
if
current_char
in
self
.
data
:
if
current_char
in
self
.
data
:
states
[
current
]
=
self
.
data
[
current_char
]
states
[
current
]
=
self
.
data
[
current_char
]
# We have a cut at the end with states.
for
start
,
trie_pointer
in
states
.
items
():
if
""
in
trie_pointer
:
# This is a final match, we need to reset and
# store the results in `offsets`.
end
=
len
(
text
)
offsets
.
append
(
start
)
offsets
.
append
(
end
)
# Longest cut is always the one with lower start so the first
# item so we need to break.
break
# We have all the offsets now, we just need to do the actual splitting.
# We have all the offsets now, we just need to do the actual splitting.
# We need to eventually add the first part of the string and the eventual
# We need to eventually add the first part of the string and the eventual
# last part.
# last part.
...
...
tests/test_tokenization_common.py
View file @
7079a99e
...
@@ -3562,3 +3562,15 @@ class TrieTest(unittest.TestCase):
...
@@ -3562,3 +3562,15 @@ class TrieTest(unittest.TestCase):
trie
.
add
(
"extra_id_1"
)
trie
.
add
(
"extra_id_1"
)
trie
.
add
(
"extra_id_100"
)
trie
.
add
(
"extra_id_100"
)
self
.
assertEqual
(
trie
.
split
(
"[CLS] This is a extra_id_100"
),
[
"[CLS]"
,
" This is a "
,
"extra_id_100"
])
self
.
assertEqual
(
trie
.
split
(
"[CLS] This is a extra_id_100"
),
[
"[CLS]"
,
" This is a "
,
"extra_id_100"
])
def
test_trie_single
(
self
):
trie
=
Trie
()
trie
.
add
(
"A"
)
self
.
assertEqual
(
trie
.
split
(
"ABC"
),
[
"A"
,
"BC"
])
self
.
assertEqual
(
trie
.
split
(
"BCA"
),
[
"BC"
,
"A"
])
def
test_trie_final
(
self
):
trie
=
Trie
()
trie
.
add
(
"TOKEN]"
)
trie
.
add
(
"[SPECIAL_TOKEN]"
)
self
.
assertEqual
(
trie
.
split
(
"This is something [SPECIAL_TOKEN]"
),
[
"This is something "
,
"[SPECIAL_TOKEN]"
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment