Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
70708cca
"...resnet50_tensorflow.git" did not exist on "f2f4f2dcca937475cf2c2c148e3844af022f04f7"
Unverified
Commit
70708cca
authored
Nov 10, 2020
by
Patrick von Platen
Committed by
GitHub
Nov 10, 2020
Browse files
fix t5 token type ids (#8437)
parent
9fd1f562
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
58 additions
and
0 deletions
+58
-0
src/transformers/tokenization_t5.py
src/transformers/tokenization_t5.py
+22
-0
src/transformers/tokenization_t5_fast.py
src/transformers/tokenization_t5_fast.py
+22
-0
tests/test_tokenization_t5.py
tests/test_tokenization_t5.py
+14
-0
No files found.
src/transformers/tokenization_t5.py
View file @
70708cca
...
@@ -187,6 +187,28 @@ class T5Tokenizer(PreTrainedTokenizer):
...
@@ -187,6 +187,28 @@ class T5Tokenizer(PreTrainedTokenizer):
else
:
else
:
return
token_ids
+
[
self
.
eos_token_id
]
return
token_ids
+
[
self
.
eos_token_id
]
def create_token_type_ids_from_sequences(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """
    Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
    use of token type ids, therefore a list of zeros is returned.

    Args:
        token_ids_0 (:obj:`List[int]`):
            List of IDs.
        token_ids_1 (:obj:`List[int]`, `optional`):
            Optional second list of IDs for sequence pairs.

    Returns:
        :obj:`List[int]`: List of zeros.
    """
    # Each sequence is followed by a single EOS token; the mask is all zeros
    # because T5 has no notion of segment (token type) ids.
    eos = [self.eos_token_id]
    if token_ids_1 is None:
        combined = token_ids_0 + eos
    else:
        combined = token_ids_0 + eos + token_ids_1 + eos
    return [0] * len(combined)
def
build_inputs_with_special_tokens
(
def
build_inputs_with_special_tokens
(
self
,
token_ids_0
:
List
[
int
],
token_ids_1
:
Optional
[
List
[
int
]]
=
None
self
,
token_ids_0
:
List
[
int
],
token_ids_1
:
Optional
[
List
[
int
]]
=
None
)
->
List
[
int
]:
)
->
List
[
int
]:
...
...
src/transformers/tokenization_t5_fast.py
View file @
70708cca
...
@@ -191,6 +191,28 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
...
@@ -191,6 +191,28 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
token_ids_1
=
token_ids_1
+
[
self
.
eos_token_id
]
token_ids_1
=
token_ids_1
+
[
self
.
eos_token_id
]
return
self
.
prefix_tokens
+
token_ids_0
+
token_ids_1
return
self
.
prefix_tokens
+
token_ids_0
+
token_ids_1
def create_token_type_ids_from_sequences(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """
    Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
    use of token type ids, therefore a list of zeros is returned.

    Args:
        token_ids_0 (:obj:`List[int]`):
            List of IDs.
        token_ids_1 (:obj:`List[int]`, `optional`):
            Optional second list of IDs for sequence pairs.

    Returns:
        :obj:`List[int]`: List of zeros.
    """
    # T5 does not use segment embeddings, so every position maps to type 0.
    # The length still has to account for the EOS appended after each sequence.
    eos = [self.eos_token_id]
    if token_ids_1 is None:
        return [0] * len(token_ids_0 + eos)
    return [0] * len(token_ids_0 + eos + token_ids_1 + eos)
@
add_start_docstrings
(
PREPARE_SEQ2SEQ_BATCH_DOCSTRING
)
@
add_start_docstrings
(
PREPARE_SEQ2SEQ_BATCH_DOCSTRING
)
def
prepare_seq2seq_batch
(
def
prepare_seq2seq_batch
(
self
,
self
,
...
...
tests/test_tokenization_t5.py
View file @
70708cca
...
@@ -223,6 +223,20 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
...
@@ -223,6 +223,20 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self
.
assertEqual
(
expected_src_tokens
,
src_ids
)
self
.
assertEqual
(
expected_src_tokens
,
src_ids
)
self
.
assertEqual
(
expected_tgt_tokens
,
tgt_ids
)
self
.
assertEqual
(
expected_tgt_tokens
,
tgt_ids
)
def test_token_type_ids(self):
    """Slow and fast T5 tokenizers must agree on token type ids for a sequence pair,
    and the pair encoding must have the expected total length (18)."""
    first_text = ["A first paragraph for summarization."]
    second_text = ["A second paragraph for summarization."]

    # Encode the same pair with both tokenizer implementations.
    shared_kwargs = dict(add_special_tokens=True, return_token_type_ids=True)
    fast_type_ids = self.t5_base_tokenizer_fast(first_text, second_text, **shared_kwargs).token_type_ids
    slow_type_ids = self.t5_base_tokenizer(first_text, second_text, **shared_kwargs).token_type_ids

    self.assertEqual(slow_type_ids, fast_type_ids)
    self.assertEqual(len(slow_type_ids[0]), 18)
def
test_fast_and_slow_same_result
(
self
):
def
test_fast_and_slow_same_result
(
self
):
src_text
=
"<pad> Today is <unk> nice day </s>"
src_text
=
"<pad> Today is <unk> nice day </s>"
tgt_ids
=
[
0
,
1960
,
19
,
2
,
1245
,
239
,
1
]
tgt_ids
=
[
0
,
1960
,
19
,
2
,
1245
,
239
,
1
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment