chenpangpang / transformers

Commit 0efbb6e9 (unverified), authored Sep 14, 2022 by SaulLu, committed by GitHub on Sep 14, 2022

fix GPT2 token's `special_tokens_mask` when used with `add_bos_token=True` (#19036)

parent 0e245480
Showing 2 changed files with 57 additions and 0 deletions:

src/transformers/models/gpt2/tokenization_gpt2.py (+32, -0)
tests/models/gpt2/test_tokenization_gpt2.py (+25, -0)
src/transformers/models/gpt2/tokenization_gpt2.py

```python
@@ -261,6 +261,38 @@ class GPT2Tokenizer(PreTrainedTokenizer):
        return output + bos_token_ids + token_ids_1

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if not self.add_bos_token:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=False
            )

        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0))
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))

    def _tokenize(self, text):
        """Tokenize a string."""
        bpe_tokens = []
```
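For context on what the new override changes in practice: with `add_bos_token=True`, `build_inputs_with_special_tokens` prepends the BOS token, but the inherited `get_special_tokens_mask` did not flag it, so the returned mask was all zeros. A minimal sketch of the fixed behavior, assuming the `gpt2` checkpoint can be downloaded (GPT-2 reuses `<|endoftext|>`, id 50256, as its BOS token):

```python
# Minimal sketch of the behavior after this fix, assuming the `gpt2`
# vocabulary is available for download.
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2", add_bos_token=True)

enc = tokenizer("Encode this.", return_special_tokens_mask=True)
print(enc["input_ids"])            # [50256, ...]: BOS id followed by the sentence ids
print(enc["special_tokens_mask"])  # [1, 0, 0, ...]: the prepended BOS is now marked
```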
tests/models/gpt2/test_tokenization_gpt2.py

```python
@@ -250,3 +250,28 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    # tokenizer has no padding token
    def test_padding_different_model_input_name(self):
        pass

    def test_special_tokens_mask_input_pairs_and_bos_token(self):
        # TODO: change to self.get_tokenizers() when the fast version is implemented
        tokenizers = [self.get_tokenizer(do_lower_case=False, add_bos_token=True)]
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                sequence_0 = "Encode this."
                sequence_1 = "This one too please."
                encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
                encoded_sequence += tokenizer.encode(sequence_1, add_special_tokens=False)
                encoded_sequence_dict = tokenizer.encode_plus(
                    sequence_0,
                    sequence_1,
                    add_special_tokens=True,
                    return_special_tokens_mask=True,
                )
                encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
                special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
                self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))

                filtered_sequence = [
                    (x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)
                ]
                filtered_sequence = [x for x in filtered_sequence if x is not None]
                self.assertEqual(encoded_sequence, filtered_sequence)
```
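The sequence-pair round trip the test asserts can also be seen end to end; a short sketch under the same assumption that the `gpt2` checkpoint is available:

```python
# Sketch of the sequence-pair round trip asserted by the new test,
# assuming the `gpt2` checkpoint is available.
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2", add_bos_token=True)

enc = tokenizer.encode_plus(
    "Encode this.", "This one too please.", add_special_tokens=True, return_special_tokens_mask=True
)

# With add_bos_token=True the pair is encoded as BOS + ids_0 + BOS + ids_1,
# and the fixed mask flags both BOS positions: [1, 0, ..., 0, 1, 0, ..., 0].
filtered = [tok for tok, m in zip(enc["input_ids"], enc["special_tokens_mask"]) if not m]
raw = tokenizer.encode("Encode this.", add_special_tokens=False)
raw += tokenizer.encode("This one too please.", add_special_tokens=False)
assert filtered == raw  # dropping masked positions recovers the raw encodings
```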