chenpangpang / transformers · Commit 407da12e (unverified)

[T5Tokenizer] add prepare_seq2seq_batch method (#6122)

* tests

Authored Aug 17, 2020 by Suraj Patil; committed by GitHub on Aug 17, 2020.
Parent: c9564f53

Showing 2 changed files with 206 additions and 1 deletion (+206, -1):

- src/transformers/tokenization_t5.py (+128, -1)
- tests/test_tokenization_t5.py (+78, -0)
src/transformers/tokenization_t5.py
@@ -19,8 +19,9 @@ import logging
 import os
 import re
 from shutil import copyfile
+from typing import List, Optional

-from .tokenization_utils import PreTrainedTokenizer
+from .tokenization_utils import BatchEncoding, PreTrainedTokenizer


 logger = logging.getLogger(__name__)
@@ -96,6 +97,8 @@ class T5Tokenizer(PreTrainedTokenizer):
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
     model_input_names = ["attention_mask"]

+    prefix_tokens: List[int] = []
+
     def __init__(
         self,
         vocab_file,
@@ -206,3 +209,127 @@ class T5Tokenizer(PreTrainedTokenizer):
             copyfile(self.vocab_file, out_vocab_file)

         return (out_vocab_file,)
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks
+        by concatenating and adding special tokens. The special tokens depend on whether source text
+        or target text is being tokenized. A T5 sequence has the following format, where ``X``
+        represents the sequence:
+
+        - ``input_ids`` (for the encoder): ``X [eos]``
+        - ``decoder_input_ids`` (for the decoder): ``[pad] X [eos]``
+
+        Pairs of sequences are not the expected use case, but they will be handled without a separator.
+
+        Args:
+            token_ids_0 (:obj:`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (:obj:`List[int]`, `optional`):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate
+            special tokens.
+        """
+        if token_ids_1 is None:
+            return self.prefix_tokens + token_ids_0
+        # We don't expect to process pairs, but leave the pair logic for API consistency
+        return self.prefix_tokens + token_ids_0 + token_ids_1
+
+    def prepare_seq2seq_batch(
+        self,
+        src_texts: List[str],
+        tgt_texts: Optional[List[str]] = None,
+        max_length: Optional[int] = None,
+        max_target_length: Optional[int] = None,
+        padding: str = "longest",
+        return_tensors: str = None,
+        truncation: bool = True,
+        **kwargs,
+    ) -> BatchEncoding:
+        r"""
+        Prepare a batch that can be passed directly to an instance of :class:`~transformers.T5Model`.
+
+        Args:
+            src_texts (:obj:`List[str]`):
+                List of documents to summarize or source language texts.
+            tgt_texts (:obj:`List[str]`, `optional`):
+                List of summaries or target language texts.
+            max_length (:obj:`int`, `optional`):
+                Controls the maximum length for encoder inputs (documents to summarize or source
+                language texts). If left unset or set to :obj:`None`, this will use the predefined
+                model maximum length if a maximum length is required by one of the truncation/padding
+                parameters. If the model has no specific maximum input length (like XLNet),
+                truncation/padding to a maximum length will be deactivated.
+            max_target_length (:obj:`int`, `optional`):
+                Controls the maximum length of decoder inputs (target language texts or summaries).
+                If left unset or set to :obj:`None`, this will use the :obj:`max_length` value.
+            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`"longest"`):
+                Activates and controls padding. Accepts the following values:
+
+                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no
+                  padding if only a single sequence is provided).
+                * :obj:`'max_length'`: Pad to a maximum length specified with the argument
+                  :obj:`max_length` or to the maximum acceptable input length for the model if that
+                  argument is not provided.
+                * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with
+                  sequences of different lengths).
+            return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
+                If set, will return tensors instead of lists of Python integers. Acceptable values are:
+
+                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
+                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
+                * :obj:`'np'`: Return NumPy :obj:`np.ndarray` objects.
+            truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
+                Activates and controls truncation. Accepts the following values:
+
+                * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with
+                  the argument :obj:`max_length` or to the maximum acceptable input length for the
+                  model if that argument is not provided. This will truncate token by token, removing
+                  a token from the longest sequence in the pair if a pair of sequences (or a batch of
+                  pairs) is provided.
+                * :obj:`'only_first'`: Truncate to a maximum length specified with the argument
+                  :obj:`max_length` or to the maximum acceptable input length for the model if that
+                  argument is not provided. This will only truncate the first sequence of a pair if a
+                  pair of sequences (or a batch of pairs) is provided.
+                * :obj:`'only_second'`: Truncate to a maximum length specified with the argument
+                  :obj:`max_length` or to the maximum acceptable input length for the model if that
+                  argument is not provided. This will only truncate the second sequence of a pair if a
+                  pair of sequences (or a batch of pairs) is provided.
+                * :obj:`False` or :obj:`'do_not_truncate'`: No truncation (i.e., can output a batch
+                  with sequence lengths greater than the model maximum admissible input size).
+            **kwargs:
+                Additional keyword arguments passed along to :obj:`self.__call__`.
+
+        Returns:
+            :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the
+            following fields:
+
+            - **input_ids** -- List of token ids to be fed to the encoder.
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by
+              the model.
+            - **decoder_input_ids** -- List of token ids to be fed to the decoder.
+            - **decoder_attention_mask** -- List of indices specifying which tokens should be attended
+              to by the decoder. This does not include the causal mask, which is built by the model.
+
+            The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``
+            will only be returned if :obj:`tgt_texts` is passed. Otherwise, :obj:`input_ids` and
+            :obj:`attention_mask` will be the only keys.
+        """
+        if max_length is None:
+            max_length = self.max_len
+        self.prefix_tokens = []
+        model_inputs: BatchEncoding = self(
+            src_texts,
+            add_special_tokens=True,
+            return_tensors=return_tensors,
+            max_length=max_length,
+            padding=padding,
+            truncation=truncation,
+            **kwargs,
+        )
+        if tgt_texts is None:
+            return model_inputs
+        # Process tgt_texts
+        if max_target_length is None:
+            max_target_length = max_length
+        # set prefix_tokens for target text
+        self.prefix_tokens = [self.pad_token_id]
+        decoder_inputs: BatchEncoding = self(
+            tgt_texts,
+            add_special_tokens=True,
+            return_tensors=return_tensors,
+            padding=padding,
+            max_length=max_target_length,
+            truncation=truncation,
+            **kwargs,
+        )
+        for k, v in decoder_inputs.items():
+            model_inputs[f"decoder_{k}"] = v
+        self.prefix_tokens = []
+        return model_inputs
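For orientation, here is a minimal usage sketch of the method added above. It is not part of the commit: the prompt strings are illustrative, PyTorch is assumed for return_tensors="pt", and "t5-small" is the same checkpoint the tests below use.

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")

# Source-only batch: returns input_ids and attention_mask.
batch = tokenizer.prepare_seq2seq_batch(
    ["translate English to German: I am a small frog."],
    return_tensors="pt",
)
print(sorted(batch.keys()))  # ['attention_mask', 'input_ids']

# Source + target batch: decoder_input_ids / decoder_attention_mask are added,
# and each target sequence is prefixed with the pad token via prefix_tokens.
batch = tokenizer.prepare_seq2seq_batch(
    src_texts=["translate English to German: I am a small frog."],
    tgt_texts=["Ich bin ein kleiner Frosch."],
    max_length=64,
    max_target_length=32,
    padding="max_length",
    return_tensors="pt",
)
print(sorted(batch.keys()))
# ['attention_mask', 'decoder_attention_mask', 'decoder_input_ids', 'input_ids']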
tests/test_tokenization_t5.py
@@ -17,6 +17,8 @@
 import os
 import unittest

+from transformers import BatchEncoding
+from transformers.testing_utils import _torch_available
 from transformers.tokenization_t5 import T5Tokenizer
 from transformers.tokenization_xlnet import SPIECE_UNDERLINE
@@ -25,6 +27,8 @@ from .test_tokenization_common import TokenizerTesterMixin
 SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")

+FRAMEWORK = "pt" if _torch_available else "tf"
+

 class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@@ -102,3 +106,77 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                 ".",
             ],
         )
+
+    def test_prepare_seq2seq_batch(self):
+        tokenizer = T5Tokenizer.from_pretrained("t5-small")
+        src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
+        tgt_text = [
+            "Summary of the text.",
+            "Another summary.",
+        ]
+        expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5]
+        batch = tokenizer.prepare_seq2seq_batch(
+            src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors=FRAMEWORK
+        )
+        self.assertIsInstance(batch, BatchEncoding)
+
+        self.assertEqual((2, 9), batch.input_ids.shape)
+        self.assertEqual((2, 9), batch.attention_mask.shape)
+        result = list(batch.input_ids.numpy()[0])
+        self.assertListEqual(expected_src_tokens, result)
+        # Test that special tokens are reset
+        self.assertEqual(tokenizer.prefix_tokens, [])
+
+    def test_empty_target_text(self):
+        tokenizer = T5Tokenizer.from_pretrained("t5-small")
+        src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
+        batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors=FRAMEWORK)
+        # check if input_ids are returned and no decoder_input_ids
+        self.assertIn("input_ids", batch)
+        self.assertIn("attention_mask", batch)
+        self.assertNotIn("decoder_input_ids", batch)
+        self.assertNotIn("decoder_attention_mask", batch)
+
+    def test_max_target_length(self):
+        tokenizer = T5Tokenizer.from_pretrained("t5-small")
+        src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
+        tgt_text = [
+            "Summary of the text.",
+            "Another summary.",
+        ]
+        batch = tokenizer.prepare_seq2seq_batch(
+            src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors=FRAMEWORK
+        )
+        self.assertEqual(32, batch["decoder_input_ids"].shape[1])
+        self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
+
+        # test None max_target_length
+        batch = tokenizer.prepare_seq2seq_batch(
+            src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors=FRAMEWORK
+        )
+        self.assertEqual(32, batch["decoder_input_ids"].shape[1])
+        self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
+
+    def test_outputs_not_longer_than_maxlen(self):
+        tokenizer = T5Tokenizer.from_pretrained("t5-small")
+
+        batch = tokenizer.prepare_seq2seq_batch(
+            ["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK
+        )
+        self.assertIsInstance(batch, BatchEncoding)
+        self.assertEqual(batch.input_ids.shape, (2, 512))
+
+    def test_eos_in_input(self):
+        tokenizer = T5Tokenizer.from_pretrained("t5-small")
+        src_text = ["A long paragraph for summrization. </s>"]
+        tgt_text = ["Summary of the text. </s>"]
+        expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, 1]
+        expected_tgt_tokens = [0, 20698, 13, 8, 1499, 5, 1]
+
+        batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK)
+        src_ids = list(batch.input_ids.numpy()[0])
+        tgt_ids = list(batch.decoder_input_ids.numpy()[0])
+
+        self.assertEqual(expected_src_tokens, src_ids)
+        self.assertEqual(expected_tgt_tokens, tgt_ids)
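To connect test_eos_in_input back to the ``[pad] X [eos]`` format documented above, here is a short hedged sketch, not part of the commit. It assumes the t5-small vocabulary, where the pad id is 0 and the eos id is 1; note that at this commit the trailing eos comes from the literal "</s>" typed in the text, not from the tokenizer itself.

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
batch = tokenizer.prepare_seq2seq_batch(
    ["A long paragraph for summrization. </s>"],
    tgt_texts=["Summary of the text. </s>"],
)
decoder_ids = batch["decoder_input_ids"][0]
# decoder_ids == [0, 20698, 13, 8, 1499, 5, 1]
#                 ^ pad prefix               ^ eos from the explicit </s>
assert decoder_ids[0] == tokenizer.pad_token_id   # [pad] prefix added via prefix_tokens
assert decoder_ids[-1] == tokenizer.eos_token_id  # </s> present in the raw target string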