Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
0efbb6e9
Unverified
Commit
0efbb6e9
authored
Sep 14, 2022
by
SaulLu
Committed by
GitHub
Sep 14, 2022
Browse files
Fix GPT2 tokenizer's `special_tokens_mask` when used with `add_bos_token=True` (#19036)
parent
0e245480
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
57 additions
and
0 deletions
+57
-0
src/transformers/models/gpt2/tokenization_gpt2.py
src/transformers/models/gpt2/tokenization_gpt2.py
+32
-0
tests/models/gpt2/test_tokenization_gpt2.py
tests/models/gpt2/test_tokenization_gpt2.py
+25
-0
No files found.
src/transformers/models/gpt2/tokenization_gpt2.py
View file @
0efbb6e9
...
...
@@ -261,6 +261,38 @@ class GPT2Tokenizer(PreTrainedTokenizer):
return
output
+
bos_token_ids
+
token_ids_1
def get_special_tokens_mask(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
    """
    Build the special-tokens mask for one sequence or a pair of sequences.

    When `self.add_bos_token` is set and the input has not been formatted yet, one leading position is
    marked as special in front of *each* sequence, followed by one `0` per regular token. In every other
    case the computation is delegated to the parent class.

    Args:
        token_ids_0 (`List[int]`):
            List of IDs.
        token_ids_1 (`List[int]`, *optional*):
            Optional second list of IDs for sequence pairs.
        already_has_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not the token list is already formatted with special tokens for the model.

    Returns:
        `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
    """
    # Nothing GPT2-specific to do when the ids are already formatted or when no
    # leading special token is added: let the parent implementation handle it.
    if already_has_special_tokens or not self.add_bos_token:
        return super().get_special_tokens_mask(
            token_ids_0=token_ids_0,
            token_ids_1=token_ids_1,
            already_has_special_tokens=bool(already_has_special_tokens),
        )

    # One special position, then one regular position per token of the first sequence.
    mask = [1] + [0] * len(token_ids_0)
    if token_ids_1 is None:
        return mask

    # The second sequence of a pair gets its own leading special position as well.
    return mask + [1] + [0] * len(token_ids_1)
def
_tokenize
(
self
,
text
):
"""Tokenize a string."""
bpe_tokens
=
[]
...
...
tests/models/gpt2/test_tokenization_gpt2.py
View file @
0efbb6e9
...
...
@@ -250,3 +250,28 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
# tokenizer has no padding token
def test_padding_different_model_input_name(self):
    # Intentionally a no-op: per the note above, there is no padding token to test with.
    # NOTE(review): presumably this overrides a same-named test inherited from
    # TokenizerTesterMixin so that it is skipped for GPT2 - confirm against the mixin.
    pass
def test_special_tokens_mask_input_pairs_and_bos_token(self):
    """
    Check `special_tokens_mask` for a sequence pair encoded with `add_bos_token=True`.

    The mask must have one entry per input id, and dropping every position flagged as special must give
    back exactly the ids obtained by encoding the two sequences without special tokens.
    """
    # TODO: change to self.get_tokenizers() when the fast version is implemented
    tokenizers = [self.get_tokenizer(do_lower_case=False, add_bos_token=True)]
    for tokenizer in tokenizers:
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            sequence_0 = "Encode this."
            sequence_1 = "This one too please."

            # Reference ids: both sequences encoded without any special token, concatenated.
            encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
            encoded_sequence += tokenizer.encode(sequence_1, add_special_tokens=False)

            encoded_sequence_dict = tokenizer.encode_plus(
                sequence_0,
                sequence_1,
                add_special_tokens=True,
                return_special_tokens_mask=True,
            )
            encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
            special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
            self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))

            # Keep only the ids whose mask entry is 0; the lengths were asserted equal
            # just above, so zip() does not drop any position.
            filtered_sequence = [
                token_id
                for token_id, is_special in zip(encoded_sequence_w_special, special_tokens_mask)
                if not is_special
            ]
            self.assertEqual(encoded_sequence, filtered_sequence)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment