Commit 7079a99e (unverified) in chenpangpang/transformers

Fixing 1-length special tokens cut. (#13862)

Authored Oct 05, 2021 by Nicolas Patry; committed via GitHub on Oct 05, 2021.
Parent: 7051b892
Showing 2 changed files with 53 additions and 29 deletions:

    src/transformers/tokenization_utils.py   +41  -29
    tests/test_tokenization_common.py        +12   -0
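What the fix amounts to, as pinned down by the new tests below: `Trie.split` (the one-pass splitter used to cut text on added special tokens) missed two kinds of cuts. A minimal repro sketch, assuming a transformers checkout that includes this commit; `Trie` is defined in the file changed below:

from transformers.tokenization_utils import Trie

# 1-length special tokens were never cut out of the surrounding text.
trie = Trie()
trie.add("A")
print(trie.split("ABC"))  # ["A", "BC"]

# Matches ending exactly at the end of the text were lost as well.
trie = Trie()
trie.add("[SPECIAL_TOKEN]")
print(trie.split("This is something [SPECIAL_TOKEN]"))
# ["This is something ", "[SPECIAL_TOKEN]"]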
src/transformers/tokenization_utils.py (view file @ 7079a99e)
@@ -20,6 +20,7 @@ import bisect
 import itertools
 import re
 import unicodedata
+from collections import OrderedDict
 from typing import Any, Dict, List, Optional, Tuple, Union, overload
 
 from .file_utils import PaddingStrategy, TensorType, add_end_docstrings
@@ -102,7 +103,6 @@ class Trie:
         >>> trie.split("[CLS] This is a extra_id_100")
         ["[CLS]", " This is a ", "extra_id_100"]
         """
-
         # indexes are counted left of the chars index.
         # "hello", index 0, is left of h, index 1 is between h and e.
         # index 5 is right of the "o".
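The docstring example above, in runnable form; the expected output implies that "[CLS]" was added to the trie earlier in the doctest, alongside the extra_id tokens (that part of the docstring is outside this hunk):

from transformers.tokenization_utils import Trie

trie = Trie()
trie.add("[CLS]")
trie.add("extra_id_1")
trie.add("extra_id_100")
# The lookahead matches longest-first, so extra_id_100 wins over its
# prefix extra_id_1.
print(trie.split("[CLS] This is a extra_id_100"))
# ["[CLS]", " This is a ", "extra_id_100"]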
@@ -115,7 +115,7 @@ class Trie:
         # If the trie contains, "blowing", and "lower" and we encounter the
         # string "blower", we need to split into ["b", "lower"].
         # This is where we need to keep track of multiple possible starts.
-        states = {}
+        states = OrderedDict()
         # This will contain every indices where we need
         # to cut.
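Why an OrderedDict rather than a plain dict: `states` maps match-start index to a trie node, and the end-of-text pass added further down takes the first finished state and breaks, so iteration has to follow insertion order (lowest start first) for the longest cut to win. A small sketch of that invariant, with toy values rather than the library's internals:

from collections import OrderedDict

states = OrderedDict()
states[0] = {"": 1}   # a match that started at index 0
states[7] = {"": 1}   # a shorter match that started later
for start, trie_pointer in states.items():
    if "" in trie_pointer:
        print("cut starts at", start)  # 0: the earliest start wins
        break

On Python 3.7+ a plain dict preserves insertion order as well; OrderedDict just makes the reliance on ordering explicit.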
@@ -144,18 +144,14 @@ class Trie:
             # In this case, we already have partial matches (But unfinished)
             for start, trie_pointer in states.items():
-                if current_char in trie_pointer:
-                    # The current character being looked at has a match within the trie
-                    # update the pointer (it will be stored back into states later).
-                    trie_pointer = trie_pointer[current_char]
                 if "" in trie_pointer:
                     # This is a final match, we need to reset and
                     # store the results in `offsets`.
 
                     # Lookahead to match longest first
                     # Important in case of extra_id_1 vs extra_id_100
-                    lookahead_index = current + 1
-                    end = current + 1
+                    lookahead_index = current
+                    end = current
                     next_char = text[lookahead_index] if lookahead_index < len(text) else None
                     while next_char in trie_pointer:
                         trie_pointer = trie_pointer[next_char]
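The heart of the fix: a finished match is now detected before the current character is consumed. For a 1-length token the stored pointer is already at a terminal node ({"": 1}), so the old guard `if current_char in trie_pointer` could never fire for it and the match was silently dropped; and since no character has been consumed yet on this iteration, the right indices are `current`, not `current + 1`. A toy walk of the fixed control flow, stripped of the lookahead and reset bookkeeping (illustrative code, not the library's):

data = {"A": {"": 1}}            # trie containing the single token "A"
text = "ABC"
states, offsets = {}, []
for current, current_char in enumerate(text):
    for start, trie_pointer in list(states.items()):
        if "" in trie_pointer:            # checked BEFORE consuming current_char
            offsets += [start, current]   # end = current, not current + 1
            states = {}
            break
    if current_char in data:
        states[current] = data[current_char]
for start, trie_pointer in states.items():  # matches that end with the text
    if "" in trie_pointer:
        offsets += [start, len(text)]
        break
print(offsets)  # [0, 1]: cut "ABC" into "A" and "BC"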
@@ -174,6 +170,10 @@ class Trie:
                     offsets.append(start)
                     offsets.append(end)
                     reset = True
+                elif current_char in trie_pointer:
+                    # The current character being looked at has a match within the trie
+                    # update the pointer (it will be stored back into states later).
+                    trie_pointer = trie_pointer[current_char]
 
                     # Storing back the new pointer into the states.
                     # Partial matches got longer by one.
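The restored branch becomes an `elif`: on each iteration a state either finishes (and triggers a cut) or advances by one character, preserving the multiple-starts behavior described in the comments earlier in the file. That comment's own example, runnable; note that "lower" ends flush with "blower", so it also exercises the end-of-text pass added below:

from transformers.tokenization_utils import Trie

trie = Trie()
trie.add("blowing")
trie.add("lower")
# Both starts stay alive while scanning "blower": the match starting at
# "b" dies at "e", while the one starting at "l" completes.
print(trie.split("blower"))  # ["b", "lower"]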
@@ -198,6 +198,18 @@ class Trie:
             if current_char in self.data:
                 states[current] = self.data[current_char]
 
+        # We have a cut at the end with states.
+        for start, trie_pointer in states.items():
+            if "" in trie_pointer:
+                # This is a final match, we need to reset and
+                # store the results in `offsets`.
+                end = len(text)
+                offsets.append(start)
+                offsets.append(end)
+                # Longest cut is always the one with lower start so the first
+                # item so we need to break.
+                break
+
         # We have all the offsets now, we just need to do the actual splitting.
         # We need to eventually add the first part of the string and the eventual
         # last part.
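After both loops, `offsets` holds pairs of cut points that the remainder of `split` turns into substrings, as the trailing comment describes. A toy sketch of that final step (the `cut` helper name is illustrative, not the library's):

def cut(text, offsets):
    # Prepend 0 and append len(text), then slice between consecutive
    # cut points, dropping empty pieces.
    points = [0] + offsets + [len(text)]
    return [text[s:e] for s, e in zip(points, points[1:]) if s != e]

print(cut("This is something [SPECIAL_TOKEN]", [18, 33]))
# ['This is something ', '[SPECIAL_TOKEN]']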
tests/test_tokenization_common.py (view file @ 7079a99e)
@@ -3562,3 +3562,15 @@ class TrieTest(unittest.TestCase):
         trie.add("extra_id_1")
         trie.add("extra_id_100")
         self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])
+
+    def test_trie_single(self):
+        trie = Trie()
+        trie.add("A")
+        self.assertEqual(trie.split("ABC"), ["A", "BC"])
+        self.assertEqual(trie.split("BCA"), ["BC", "A"])
+
+    def test_trie_final(self):
+        trie = Trie()
+        trie.add("TOKEN]")
+        trie.add("[SPECIAL_TOKEN]")
+        self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
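test_trie_final is the subtle one: "TOKEN]" is a suffix of "[SPECIAL_TOKEN]", so both candidate matches finish exactly at the end of the text. The new end-of-states pass breaks on the first finished state, i.e. the lowest start, so the longer, earlier match wins:

from transformers.tokenization_utils import Trie

trie = Trie()
trie.add("TOKEN]")
trie.add("[SPECIAL_TOKEN]")
# "[SPECIAL_TOKEN]" starts at index 18, the overlapping "TOKEN]" at 27;
# iterating `states` in insertion order keeps the earlier start.
print(trie.split("This is something [SPECIAL_TOKEN]"))
# ["This is something ", "[SPECIAL_TOKEN]"]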