chenpangpang / transformers · Commits

Commit 2670b0d6
Authored Dec 04, 2019 by Michael Watkins
Committed Dec 06, 2019 by Lysandre Debut

Fix bug which lowercases special tokens

parent 35401fe5
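For context: with do_lower_case=True, PreTrainedTokenizer.tokenize() previously ran text.lower() over the entire input, so cased special tokens such as "[CLS]" were mangled before the special-token split could match them. A minimal sketch of the symptom (hypothetical session; the checkpoint name and outputs are illustrative, not taken from this commit):

    # assumes a transformers v2.x environment
    from transformers import BertTokenizer

    tok = BertTokenizer.from_pretrained("bert-base-uncased")  # do_lower_case=True

    # Before this fix: the whole string was lowercased first, so "[CLS]"
    # became "[cls]", failed to match the special-token table, and was
    # split into ordinary word pieces.
    # After this fix: special tokens are matched verbatim and only the
    # surrounding text is lowercased.
    print(tok.tokenize("[CLS] Hello world [SEP]"))
    # expected after the fix: ['[CLS]', 'hello', 'world', '[SEP]']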
Showing 2 changed files with 18 additions and 5 deletions:

  transformers/tests/tokenization_tests_commons.py   +5 -3
  transformers/tokenization_utils.py                  +13 -2
transformers/tests/tokenization_tests_commons.py

@@ -115,8 +115,10 @@ class CommonTestCases:

         def test_added_tokens_do_lower_case(self):
             tokenizer = self.get_tokenizer(do_lower_case=True)

-            text = "aaaaa bbbbbb low cccccccccdddddddd l"
-            text2 = "AAAAA BBBBBB low CCCCCCCCCDDDDDDDD l"
+            special_token = tokenizer.all_special_tokens[0]
+
+            text = special_token + " aaaaa bbbbbb low cccccccccdddddddd l " + special_token
+            text2 = special_token + " AAAAA BBBBBB low CCCCCCCCCDDDDDDDD l " + special_token

             toks0 = tokenizer.tokenize(text)  # toks before adding new_toks
@@ -141,7 +143,7 @@ class CommonTestCases:

             self.assertEqual(len(toks), len(toks2))  # Length should still be the same
             self.assertNotEqual(len(toks), len(toks0))
-            self.assertNotEqual(toks[0], toks2[0])  # But at least the first tokens should differ
+            self.assertNotEqual(toks[1], toks2[1])  # But at least the first non-special tokens should differ

         def test_add_tokens_tokenizer(self):
             tokenizer = self.get_tokenizer()
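Both test strings now begin and end with the same special token, so position 0 is identical in both tokenizations by construction; the first place a cased/uncased difference can appear is index 1, hence the updated assertion. For reference, all_special_tokens collects the tokenizer's configured special tokens, e.g. (hypothetical output for a BERT-style tokenizer; contents and order depend on the model):

    tokenizer = self.get_tokenizer(do_lower_case=True)
    print(tokenizer.all_special_tokens)
    # e.g. ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']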
transformers/tokenization_utils.py

@@ -22,6 +22,7 @@ import json

 import six
 import copy
 import itertools
+import re

 from io import open

 from .file_utils import cached_path, is_tf_available, is_torch_available
@@ -520,7 +521,7 @@ class PreTrainedTokenizer(object):

         to_add_tokens = []
         for token in new_tokens:
             assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
-            if self.init_kwargs.get('do_lower_case', False):
+            if self.init_kwargs.get('do_lower_case', False) and token not in self.all_special_tokens:
                 token = token.lower()
             if token != self.unk_token and \
                     self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) and \
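What the tightened guard above means for callers, as a hedged sketch (hypothetical session; assumes the v2.x add_tokens / add_special_tokens API, in which special tokens are registered on the tokenizer before add_tokens runs):

    tok = BertTokenizer.from_pretrained("bert-base-uncased")  # do_lower_case=True

    tok.add_tokens(["BrandName"])
    # not a special token -> still lowercased, stored as "brandname"

    tok.add_special_tokens({"additional_special_tokens": ["[NEW]"]})
    # "[NEW]" is already in self.all_special_tokens when add_tokens sees it,
    # so the new guard skips .lower() and the token keeps its casing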
@@ -615,8 +616,18 @@ class PreTrainedTokenizer(object):

             Take care of added tokens.
         """
+        def lowercase_text(t):
+            # convert non-special tokens to lowercase
+            escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens]
+            pattern = r'(^' + r'|'.join(escaped_special_toks) + r')|' + \
+                r'(.+?)'
+            return re.sub(
+                pattern,
+                lambda m: m.groups()[0] or m.groups()[1].lower(),
+                t)
+
         if self.init_kwargs.get('do_lower_case', False):
-            text = text.lower()
+            text = lowercase_text(text)

         def split_on_token(tok, text):
             result = []
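The lowercase_text helper relies on regex alternation: group 1 matches any special token verbatim, while the non-greedy (.+?) in group 2 consumes everything else one character at a time so it can be lowercased; alternation order guarantees special tokens win over the catch-all. A standalone sketch of the same technique (illustrative names, not the library code; the committed pattern also carries a leading ^ inside group 1, omitted here):

    import re

    special_tokens = ["[CLS]", "[SEP]"]
    pattern = r"(" + r"|".join(re.escape(t) for t in special_tokens) + r")|(.+?)"

    def lowercase_keeping_specials(text):
        # group 1: a special token, copied through unchanged
        # group 2: any other single character, lowercased
        return re.sub(pattern, lambda m: m.group(1) or m.group(2).lower(), text)

    print(lowercase_keeping_specials("[CLS] Hello World [SEP]"))
    # [CLS] hello world [SEP]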