chenpangpang/transformers · Commit 5e737018 (unverified)
Authored May 28, 2020 by Anthony MOI; committed by GitHub on May 28, 2020

Fix add_special_tokens on fast tokenizers (#4531)
Parent: e444648a

Showing 2 changed files with 10 additions and 4 deletions (+10 −4):

- src/transformers/tokenization_utils.py: +8 −3
- tests/test_tokenization_fast.py: +2 −1
src/transformers/tokenization_utils.py

```diff
@@ -2400,15 +2400,20 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
     def add_special_tokens(self, special_tokens_dict: dict) -> int:
         # Map special tokens to class attributes (self.pad_token...)
-        num_added_tokens = super().add_special_tokens(special_tokens_dict)
+        super().add_special_tokens(special_tokens_dict)

         # If the backend tokenizer the only specificities of special tokens are that
         # - they will never be processed by the model, and
         # - they will be removed while decoding.
         # But they are not mapped to special attributes in the backend so we can just
         # send a list.
-        tokens = flatten(special_tokens_dict.values())
-        self._tokenizer.add_special_tokens(tokens)
+        tokens = []
+        for token in special_tokens_dict.values():
+            if isinstance(token, list):
+                tokens += token
+            else:
+                tokens += [token]
+        num_added_tokens = self._tokenizer.add_special_tokens(tokens)

         return num_added_tokens
```
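In short: the old code delegated the flattening to a `flatten(...)` helper and took its return value from the superclass call, while the new code flattens `special_tokens_dict.values()` explicitly (each value is either a single token string or, under `additional_special_tokens`, a list of strings) and takes `num_added_tokens` from the backend tokenizer, so the returned count reflects what was actually added. Below is a minimal standalone sketch of that flattening step; the example dict values are borrowed from the tests in this commit, and everything else is illustrative:

```python
# Standalone sketch of the flattening loop this commit introduces.
# Values in special_tokens_dict are either a single token (str) or,
# for "additional_special_tokens", a list of tokens; the backend
# tokenizer expects one flat list.
special_tokens_dict = {
    "bos_token": "[BOS]",
    "eos_token": "[EOS]",
    "additional_special_tokens": ["<testtoken3>", "<testtoken4>"],
}

tokens = []
for token in special_tokens_dict.values():
    if isinstance(token, list):
        tokens += token  # splice list values in element by element
    else:
        tokens += [token]  # wrap single tokens so += stays uniform

print(tokens)  # ['[BOS]', '[EOS]', '<testtoken3>', '<testtoken4>']
```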
tests/test_tokenization_fast.py

```diff
@@ -221,6 +221,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
         self.assertEqual(len(tokenizer_r), vocab_size + 3)
         self.assertEqual(tokenizer_r.add_special_tokens({}), 0)
+        self.assertEqual(tokenizer_r.add_special_tokens({"bos_token": "[BOS]", "eos_token": "[EOS]"}), 2)
         self.assertRaises(
             AssertionError, tokenizer_r.add_special_tokens, {"additional_special_tokens": "<testtoken1>"}
         )
@@ -228,7 +229,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
         self.assertEqual(
             tokenizer_r.add_special_tokens({"additional_special_tokens": ["<testtoken3>", "<testtoken4>"]}), 2
         )
-        self.assertEqual(len(tokenizer_r), vocab_size + 6)
+        self.assertEqual(len(tokenizer_r), vocab_size + 8)

     def assert_offsets_mapping(self, tokenizer_r):
         text = "Wonderful no inspiration example with subtoken"
```
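The new assertion adds two special tokens ([BOS] and [EOS]), which is why the later vocabulary-size check moves from `vocab_size + 6` to `vocab_size + 8`. For reference, a hedged usage sketch of what these assertions check end to end; the checkpoint name is an illustrative assumption (any fast tokenizer would exercise the same path), not part of this commit:

```python
from transformers import BertTokenizerFast

# Illustrative checkpoint, assumed available; not part of this commit.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
vocab_size = len(tokenizer)

# Mirrors the test above: an empty dict adds nothing, and each new
# special token grows the vocabulary by one.
assert tokenizer.add_special_tokens({}) == 0
assert tokenizer.add_special_tokens({"bos_token": "[BOS]", "eos_token": "[EOS]"}) == 2
assert len(tokenizer) == vocab_size + 2
```

Passing a plain string instead of a list for "additional_special_tokens" is rejected with an AssertionError, which the assertRaises check above pins down.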