Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7fd1d42a
Unverified
Commit
7fd1d42a
authored
Nov 27, 2019
by
Thomas Wolf
Committed by
GitHub
Nov 27, 2019
Browse files
Merge pull request #1592 from watkinsm/do_lower_case
Consider do_lower_case in PreTrainedTokenizer
parents
de2696f6
21637d49
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
0 deletions
+35
-0
transformers/tests/tokenization_tests_commons.py
transformers/tests/tokenization_tests_commons.py
+30
-0
transformers/tokenization_utils.py
transformers/tokenization_utils.py
+5
-0
No files found.
transformers/tests/tokenization_tests_commons.py
View file @
7fd1d42a
...
@@ -110,6 +110,36 @@ class CommonTestCases:
...
@@ -110,6 +110,36 @@ class CommonTestCases:
self
.
assertListEqual
(
subwords
,
subwords_loaded
)
self
.
assertListEqual
(
subwords
,
subwords_loaded
)
def
test_added_tokens_do_lower_case
(
self
):
tokenizer
=
self
.
get_tokenizer
(
do_lower_case
=
True
)
text
=
"aaaaa bbbbbb low cccccccccdddddddd l"
text2
=
"AAAAA BBBBBB low CCCCCCCCCDDDDDDDD l"
toks0
=
tokenizer
.
tokenize
(
text
)
# toks before adding new_toks
new_toks
=
[
"aaaaa bbbbbb"
,
"cccccccccdddddddd"
,
'AAAAA BBBBBB'
,
'CCCCCCCCCDDDDDDDD'
]
added
=
tokenizer
.
add_tokens
(
new_toks
)
self
.
assertEqual
(
added
,
2
)
toks
=
tokenizer
.
tokenize
(
text
)
toks2
=
tokenizer
.
tokenize
(
text2
)
self
.
assertEqual
(
len
(
toks
),
len
(
toks2
))
self
.
assertNotEqual
(
len
(
toks
),
len
(
toks0
))
# toks0 should be longer
self
.
assertListEqual
(
toks
,
toks2
)
tokenizer
=
self
.
get_tokenizer
(
do_lower_case
=
False
)
added
=
tokenizer
.
add_tokens
(
new_toks
)
self
.
assertEqual
(
added
,
4
)
toks
=
tokenizer
.
tokenize
(
text
)
toks2
=
tokenizer
.
tokenize
(
text2
)
self
.
assertEqual
(
len
(
toks
),
len
(
toks2
))
# Length should still be the same
self
.
assertNotEqual
(
len
(
toks
),
len
(
toks0
))
self
.
assertNotEqual
(
toks
[
0
],
toks2
[
0
])
# But at least the first tokens should differ
def
test_add_tokens_tokenizer
(
self
):
def
test_add_tokens_tokenizer
(
self
):
tokenizer
=
self
.
get_tokenizer
()
tokenizer
=
self
.
get_tokenizer
()
...
...
transformers/tokenization_utils.py
View file @
7fd1d42a
...
@@ -513,6 +513,8 @@ class PreTrainedTokenizer(object):
...
@@ -513,6 +513,8 @@ class PreTrainedTokenizer(object):
to_add_tokens
=
[]
to_add_tokens
=
[]
for
token
in
new_tokens
:
for
token
in
new_tokens
:
assert
isinstance
(
token
,
str
)
or
(
six
.
PY2
and
isinstance
(
token
,
unicode
))
assert
isinstance
(
token
,
str
)
or
(
six
.
PY2
and
isinstance
(
token
,
unicode
))
if
self
.
init_kwargs
.
get
(
'do_lower_case'
,
False
):
token
=
token
.
lower
()
if
token
!=
self
.
unk_token
and
\
if
token
!=
self
.
unk_token
and
\
self
.
convert_tokens_to_ids
(
token
)
==
self
.
convert_tokens_to_ids
(
self
.
unk_token
)
and
\
self
.
convert_tokens_to_ids
(
token
)
==
self
.
convert_tokens_to_ids
(
self
.
unk_token
)
and
\
token
not
in
to_add_tokens
:
token
not
in
to_add_tokens
:
...
@@ -606,6 +608,9 @@ class PreTrainedTokenizer(object):
...
@@ -606,6 +608,9 @@ class PreTrainedTokenizer(object):
Take care of added tokens.
Take care of added tokens.
"""
"""
if
self
.
init_kwargs
.
get
(
'do_lower_case'
,
False
):
text
=
text
.
lower
()
def
split_on_token
(
tok
,
text
):
def
split_on_token
(
tok
,
text
):
result
=
[]
result
=
[]
split_text
=
text
.
split
(
tok
)
split_text
=
text
.
split
(
tok
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment