Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
70708cca
Unverified
Commit
70708cca
authored
Nov 10, 2020
by
Patrick von Platen
Committed by
GitHub
Nov 10, 2020
Browse files
fix t5 token type ids (#8437)
parent
9fd1f562
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
58 additions
and
0 deletions
+58
-0
src/transformers/tokenization_t5.py
src/transformers/tokenization_t5.py
+22
-0
src/transformers/tokenization_t5_fast.py
src/transformers/tokenization_t5_fast.py
+22
-0
tests/test_tokenization_t5.py
tests/test_tokenization_t5.py
+14
-0
No files found.
src/transformers/tokenization_t5.py
View file @
70708cca
...
@@ -187,6 +187,28 @@ class T5Tokenizer(PreTrainedTokenizer):
...
@@ -187,6 +187,28 @@ class T5Tokenizer(PreTrainedTokenizer):
else
:
else
:
return
token_ids
+
[
self
.
eos_token_id
]
return
token_ids
+
[
self
.
eos_token_id
]
def create_token_type_ids_from_sequences(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """
    Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
    use of token type ids, therefore a list of zeros is returned.

    Args:
        token_ids_0 (:obj:`List[int]`):
            List of IDs.
        token_ids_1 (:obj:`List[int]`, `optional`):
            Optional second list of IDs for sequence pairs.

    Returns:
        :obj:`List[int]`: List of zeros.
    """
    # T5 appends exactly one EOS token after every sequence, so the zero-mask
    # must cover each sequence length plus one.
    mask_length = len(token_ids_0) + 1
    if token_ids_1 is not None:
        mask_length += len(token_ids_1) + 1
    return [0] * mask_length
def
build_inputs_with_special_tokens
(
def
build_inputs_with_special_tokens
(
self
,
token_ids_0
:
List
[
int
],
token_ids_1
:
Optional
[
List
[
int
]]
=
None
self
,
token_ids_0
:
List
[
int
],
token_ids_1
:
Optional
[
List
[
int
]]
=
None
)
->
List
[
int
]:
)
->
List
[
int
]:
...
...
src/transformers/tokenization_t5_fast.py
View file @
70708cca
...
@@ -191,6 +191,28 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
...
@@ -191,6 +191,28 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
token_ids_1
=
token_ids_1
+
[
self
.
eos_token_id
]
token_ids_1
=
token_ids_1
+
[
self
.
eos_token_id
]
return
self
.
prefix_tokens
+
token_ids_0
+
token_ids_1
return
self
.
prefix_tokens
+
token_ids_0
+
token_ids_1
def create_token_type_ids_from_sequences(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """
    Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
    use of token type ids, therefore a list of zeros is returned.

    Args:
        token_ids_0 (:obj:`List[int]`):
            List of IDs.
        token_ids_1 (:obj:`List[int]`, `optional`):
            Optional second list of IDs for sequence pairs.

    Returns:
        :obj:`List[int]`: List of zeros.
    """
    # Build the full token layout (each sequence followed by one EOS) and
    # emit one zero per position, since T5 never distinguishes segments.
    pieces = token_ids_0 + [self.eos_token_id]
    if token_ids_1 is not None:
        pieces = pieces + token_ids_1 + [self.eos_token_id]
    return [0 for _ in pieces]
@
add_start_docstrings
(
PREPARE_SEQ2SEQ_BATCH_DOCSTRING
)
@
add_start_docstrings
(
PREPARE_SEQ2SEQ_BATCH_DOCSTRING
)
def
prepare_seq2seq_batch
(
def
prepare_seq2seq_batch
(
self
,
self
,
...
...
tests/test_tokenization_t5.py
View file @
70708cca
...
@@ -223,6 +223,20 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
...
@@ -223,6 +223,20 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self
.
assertEqual
(
expected_src_tokens
,
src_ids
)
self
.
assertEqual
(
expected_src_tokens
,
src_ids
)
self
.
assertEqual
(
expected_tgt_tokens
,
tgt_ids
)
self
.
assertEqual
(
expected_tgt_tokens
,
tgt_ids
)
def test_token_type_ids(self):
    """Slow and fast T5 tokenizers must produce identical token type ids for a sequence pair."""
    first_batch = ["A first paragraph for summarization."]
    second_batch = ["A second paragraph for summarization."]

    # Token type ids as produced by the slow (sentencepiece-based) tokenizer.
    slow_ids = self.t5_base_tokenizer(
        first_batch,
        second_batch,
        add_special_tokens=True,
        return_token_type_ids=True,
    ).token_type_ids
    # Token type ids as produced by the fast (Rust-backed) tokenizer.
    fast_ids = self.t5_base_tokenizer_fast(
        first_batch,
        second_batch,
        add_special_tokens=True,
        return_token_type_ids=True,
    ).token_type_ids

    # Both implementations must agree, and the encoded pair is expected to
    # span 18 positions for these inputs.
    self.assertEqual(slow_ids, fast_ids)
    self.assertEqual(len(slow_ids[0]), 18)
def
test_fast_and_slow_same_result
(
self
):
def
test_fast_and_slow_same_result
(
self
):
src_text
=
"<pad> Today is <unk> nice day </s>"
src_text
=
"<pad> Today is <unk> nice day </s>"
tgt_ids
=
[
0
,
1960
,
19
,
2
,
1245
,
239
,
1
]
tgt_ids
=
[
0
,
1960
,
19
,
2
,
1245
,
239
,
1
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment