Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
407da12e
Unverified
Commit
407da12e
authored
Aug 17, 2020
by
Suraj Patil
Committed by
GitHub
Aug 17, 2020
Browse files
[T5Tokenizer] add prepare_seq2seq_batch method (#6122)
* tests
parent
c9564f53
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
206 additions
and
1 deletion
+206
-1
src/transformers/tokenization_t5.py
src/transformers/tokenization_t5.py
+128
-1
tests/test_tokenization_t5.py
tests/test_tokenization_t5.py
+78
-0
No files found.
src/transformers/tokenization_t5.py
View file @
407da12e
...
@@ -19,8 +19,9 @@ import logging
...
@@ -19,8 +19,9 @@ import logging
import
os
import
os
import
re
import
re
from
shutil
import
copyfile
from
shutil
import
copyfile
from
typing
import
List
,
Optional
from
.tokenization_utils
import
PreTrainedTokenizer
from
.tokenization_utils
import
BatchEncoding
,
PreTrainedTokenizer
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -96,6 +97,8 @@ class T5Tokenizer(PreTrainedTokenizer):
...
@@ -96,6 +97,8 @@ class T5Tokenizer(PreTrainedTokenizer):
max_model_input_sizes
=
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
max_model_input_sizes
=
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names
=
[
"attention_mask"
]
model_input_names
=
[
"attention_mask"
]
prefix_tokens
:
List
[
int
]
=
[]
def
__init__
(
def
__init__
(
self
,
self
,
vocab_file
,
vocab_file
,
...
@@ -206,3 +209,127 @@ class T5Tokenizer(PreTrainedTokenizer):
...
@@ -206,3 +209,127 @@ class T5Tokenizer(PreTrainedTokenizer):
copyfile
(
self
.
vocab_file
,
out_vocab_file
)
copyfile
(
self
.
vocab_file
,
out_vocab_file
)
return
(
out_vocab_file
,)
return
(
out_vocab_file
,)
def build_inputs_with_special_tokens(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """
    Assemble the final id list for a sequence (or a pair of sequences) by prepending ``self.prefix_tokens``.

    ``self.prefix_tokens`` is whatever the tokenizer currently holds: :obj:`prepare_seq2seq_batch` sets it to an
    empty list while encoding source texts and to ``[self.pad_token_id]`` while encoding target texts. This
    method itself adds nothing else: no separator between the two sequences and no end-of-sequence id.

    Pairs of sequences are not the expected use case; the pair branch exists only for API consistency.

    Args:
        token_ids_0 (:obj:`List[int]`):
            Ids of the first (or only) sequence.
        token_ids_1 (:obj:`List[int]`, `optional`):
            Ids of an optional second sequence, appended directly after the first.

    Returns:
        :obj:`List[int]`: A new list made of ``self.prefix_tokens``, then ``token_ids_0``, then ``token_ids_1``
        when it is given.
    """
    ids = self.prefix_tokens + token_ids_0
    if token_ids_1 is not None:
        # No separator is inserted between the two sequences.
        ids = ids + token_ids_1
    return ids
def prepare_seq2seq_batch(
    self,
    src_texts: List[str],
    tgt_texts: Optional[List[str]] = None,
    max_length: Optional[int] = None,
    max_target_length: Optional[int] = None,
    padding: str = "longest",
    return_tensors: Optional[str] = None,
    truncation: bool = True,
    **kwargs,
) -> BatchEncoding:
    r"""
    Prepare a batch that can be passed directly to an instance of :class:`~transformers.T5Model`.

    The source texts are encoded with an empty ``self.prefix_tokens``; the target texts, when given, are encoded
    with ``self.prefix_tokens`` set to ``[self.pad_token_id]`` so that every target sequence starts with the pad
    id. ``self.prefix_tokens`` is always reset to an empty list before this method returns or raises.

    Args:
        src_texts: (:obj:`List[str]`):
            List of documents to summarize or source language texts.
        tgt_texts: (:obj:`List[str]`, `optional`):
            List of summaries or target language texts.
        max_length (:obj:`int`, `optional`):
            Controls the maximum length for encoder inputs (documents to summarize or source language texts).
            If left unset or set to :obj:`None`, ``self.max_len`` is used.
        max_target_length (:obj:`int`, `optional`):
            Controls the maximum length of decoder inputs (target language texts or summaries).
            If left unset or set to :obj:`None`, this will use the max_length value.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`'longest'`):
            Activates and controls padding; forwarded unchanged to :obj:`self.__call__`. Accepts the following
            values:

            * :obj:`True` or :obj:`'longest'` (default): Pad to the longest sequence in the batch (or no padding
              if only a single sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to :obj:`None`):
            If set, will return tensors instead of list of python integers. Acceptable values are:

            * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
            * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
            * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
        truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
            Activates and controls truncation; forwarded unchanged to :obj:`self.__call__`. Accepts the following
            values:

            * :obj:`True` or :obj:`'longest_first'` (default): Truncate to a maximum length specified with the
              argument :obj:`max_length` or to the maximum acceptable input length for the model if that argument
              is not provided. This will truncate token by token, removing a token from the longest sequence in
              the pair if a pair of sequences (or a batch of pairs) is provided.
            * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
              the maximum acceptable input length for the model if that argument is not provided. This will only
              truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
            * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
              to the maximum acceptable input length for the model if that argument is not provided. This will only
              truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
            * :obj:`False` or :obj:`'do_not_truncate'`: No truncation (i.e., can output batch with
              sequence lengths greater than the model maximum admissible input size).
        **kwargs:
            Additional keyword arguments passed along to :obj:`self.__call__` (for both the source and the target
            encoding).

    Returns:
        :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:

        - **input_ids** -- List of token ids to be fed to the encoder.
        - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
        - **decoder_input_ids** -- List of token ids to be fed to the decoder.
        - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder.
          This does not include causal mask, which is built by the model.

        Every key produced by the target encoding is copied into the result with a ``decoder_`` prefix.
        The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``,
        will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys.
    """
    if max_length is None:
        max_length = self.max_len

    # Source texts are encoded without any prefix token.
    self.prefix_tokens = []
    model_inputs: BatchEncoding = self(
        src_texts,
        add_special_tokens=True,
        return_tensors=return_tensors,
        max_length=max_length,
        padding=padding,
        truncation=truncation,
        **kwargs,
    )
    if tgt_texts is None:
        return model_inputs

    # Process tgt_texts
    if max_target_length is None:
        max_target_length = max_length

    # Target texts are encoded with the pad id as prefix. The prefix lives on the tokenizer instance, so it
    # must be cleared even when the encoding raises; otherwise every later call would silently prepend it.
    self.prefix_tokens = [self.pad_token_id]
    try:
        decoder_inputs: BatchEncoding = self(
            tgt_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            padding=padding,
            max_length=max_target_length,
            truncation=truncation,
            **kwargs,
        )
    finally:
        self.prefix_tokens = []

    for k, v in decoder_inputs.items():
        model_inputs[f"decoder_{k}"] = v
    return model_inputs
tests/test_tokenization_t5.py
View file @
407da12e
...
@@ -17,6 +17,8 @@
...
@@ -17,6 +17,8 @@
import
os
import
os
import
unittest
import
unittest
from
transformers
import
BatchEncoding
from
transformers.testing_utils
import
_torch_available
from
transformers.tokenization_t5
import
T5Tokenizer
from
transformers.tokenization_t5
import
T5Tokenizer
from
transformers.tokenization_xlnet
import
SPIECE_UNDERLINE
from
transformers.tokenization_xlnet
import
SPIECE_UNDERLINE
...
@@ -25,6 +27,8 @@ from .test_tokenization_common import TokenizerTesterMixin
...
@@ -25,6 +27,8 @@ from .test_tokenization_common import TokenizerTesterMixin
SAMPLE_VOCAB
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"fixtures/test_sentencepiece.model"
)
SAMPLE_VOCAB
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"fixtures/test_sentencepiece.model"
)
FRAMEWORK
=
"pt"
if
_torch_available
else
"tf"
class
T5TokenizationTest
(
TokenizerTesterMixin
,
unittest
.
TestCase
):
class
T5TokenizationTest
(
TokenizerTesterMixin
,
unittest
.
TestCase
):
...
@@ -102,3 +106,77 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
...
@@ -102,3 +106,77 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
"."
,
"."
,
],
],
)
)
def test_prepare_seq2seq_batch(self):
    """Encode a source/target batch and check type, shapes, first-row ids and the prefix-token reset."""
    tok = T5Tokenizer.from_pretrained("t5-small")
    sources = ["A long paragraph for summrization.", "Another paragraph for summrization."]
    targets = [
        "Summary of the text.",
        "Another summary.",
    ]
    expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5]

    encoded = tok.prepare_seq2seq_batch(
        sources,
        tgt_texts=targets,
        max_length=len(expected_src_tokens),
        return_tensors=FRAMEWORK,
    )
    self.assertIsInstance(encoded, BatchEncoding)

    # Two source texts, capped at the nine expected source tokens.
    self.assertEqual((2, 9), encoded.input_ids.shape)
    self.assertEqual((2, 9), encoded.attention_mask.shape)

    first_row = list(encoded.input_ids.numpy()[0])
    self.assertListEqual(expected_src_tokens, first_row)

    # Test that special tokens are reset
    self.assertEqual(tok.prefix_tokens, [])
def test_empty_target_text(self):
    """Without target texts, only encoder keys may appear in the returned batch."""
    tok = T5Tokenizer.from_pretrained("t5-small")
    sources = ["A long paragraph for summrization.", "Another paragraph for summrization."]
    encoded = tok.prepare_seq2seq_batch(sources, return_tensors=FRAMEWORK)

    # check if input_ids are returned and no decoder_input_ids
    for present_key in ("input_ids", "attention_mask"):
        self.assertIn(present_key, encoded)
    for absent_key in ("decoder_input_ids", "decoder_attention_mask"):
        self.assertNotIn(absent_key, encoded)
def test_max_target_length(self):
    """Decoder tensors are padded to 32 both via ``max_target_length`` and via the ``max_length`` fallback."""
    tok = T5Tokenizer.from_pretrained("t5-small")
    sources = ["A long paragraph for summrization.", "Another paragraph for summrization."]
    targets = [
        "Summary of the text.",
        "Another summary.",
    ]

    # First pass sets max_target_length explicitly; second pass leaves it None so max_length is used instead.
    for length_kwargs in ({"max_target_length": 32}, {"max_length": 32}):
        encoded = tok.prepare_seq2seq_batch(
            sources,
            tgt_texts=targets,
            padding="max_length",
            return_tensors=FRAMEWORK,
            **length_kwargs,
        )
        self.assertEqual(32, encoded["decoder_input_ids"].shape[1])
        self.assertEqual(32, encoded["decoder_attention_mask"].shape[1])
def test_outputs_not_longer_than_maxlen(self):
    """A very long source text must be truncated so the batch has shape (2, 512)."""
    tok = T5Tokenizer.from_pretrained("t5-small")
    short_text = "I am a small frog"
    very_long_text = short_text * 1000

    encoded = tok.prepare_seq2seq_batch([very_long_text, short_text], return_tensors=FRAMEWORK)

    self.assertIsInstance(encoded, BatchEncoding)
    self.assertEqual(encoded.input_ids.shape, (2, 512))
def test_eos_in_input(self):
    """Texts that already end in ``</s>`` encode to the expected ids; the target row starts with id 0."""
    tok = T5Tokenizer.from_pretrained("t5-small")
    sources = ["A long paragraph for summrization. </s>"]
    targets = ["Summary of the text. </s>"]
    expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, 1]
    expected_tgt_tokens = [0, 20698, 13, 8, 1499, 5, 1]

    encoded = tok.prepare_seq2seq_batch(sources, tgt_texts=targets, return_tensors=FRAMEWORK)

    encoder_row = list(encoded.input_ids.numpy()[0])
    decoder_row = list(encoded.decoder_input_ids.numpy()[0])

    self.assertEqual(expected_src_tokens, encoder_row)
    self.assertEqual(expected_tgt_tokens, decoder_row)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment