"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "559790f9e4576b4a21b3efb01711abba6b5ac1c4"
Commit 9f374c82 authored by LysandreJik's avatar LysandreJik
Browse files

`encode` and `encode_plus` handle attention masks and padding

parent 72e506b2
......@@ -335,3 +335,54 @@ class CommonTestCases:
special_tokens_mask = tokenizer.get_special_tokens_mask(encoded_sequence_w_special, already_has_special_tokens=True)
self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
self.assertEqual(special_tokens_mask_orig, special_tokens_mask)
def test_padding_to_max_length(self):
tokenizer = self.get_tokenizer()
sequence = "Sequence"
padding_size = 10
padding_idx = tokenizer.pad_token_id
# Check that it correctly pads when a maximum length is specified along with the padding flag set to True
encoded_sequence = tokenizer.encode(sequence)
sequence_length = len(encoded_sequence)
padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True)
padded_sequence_length = len(padded_sequence)
assert sequence_length + padding_size == padded_sequence_length
assert encoded_sequence + [padding_idx] * padding_size == padded_sequence
# Check that nothing is done when a maximum length is not specified
encoded_sequence = tokenizer.encode(sequence)
sequence_length = len(encoded_sequence)
padded_sequence = tokenizer.encode(sequence, pad_to_max_length=True)
padded_sequence_length = len(padded_sequence)
assert sequence_length == padded_sequence_length
assert encoded_sequence == padded_sequence
def test_encode_plus_with_padding(self):
tokenizer = self.get_tokenizer()
sequence = "Sequence"
padding_size = 10
padding_idx = tokenizer.pad_token_id
token_type_padding_idx = tokenizer.pad_token_type_id
encoded_sequence = tokenizer.encode_plus(sequence, return_special_tokens_mask=True)
input_ids = encoded_sequence['input_ids']
token_type_ids = encoded_sequence['token_type_ids']
attention_mask = encoded_sequence['attention_mask']
special_tokens_mask = encoded_sequence['special_tokens_mask']
sequence_length = len(input_ids)
padded_sequence = tokenizer.encode_plus(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True, return_special_tokens_mask=True)
padded_input_ids = padded_sequence['input_ids']
padded_token_type_ids = padded_sequence['token_type_ids']
padded_attention_mask = padded_sequence['attention_mask']
padded_special_tokens_mask = padded_sequence['special_tokens_mask']
padded_sequence_length = len(padded_input_ids)
assert sequence_length + padding_size == padded_sequence_length
assert input_ids + [padding_idx] * padding_size == padded_input_ids
assert token_type_ids + [token_type_padding_idx] * padding_size == padded_token_type_ids
assert attention_mask + [0] * padding_size == padded_attention_mask
assert special_tokens_mask + [1] * padding_size == padded_special_tokens_mask
\ No newline at end of file
......@@ -190,6 +190,11 @@ class PreTrainedTokenizer(object):
""" Id of the padding token in the vocabulary. Log an error if used while not having been set. """
return self.convert_tokens_to_ids(self.pad_token)
@property
def pad_token_type_id(self):
""" Id of the padding token in the vocabulary. Log an error if used while not having been set. """
return self._pad_token_type_id
@property
def cls_token_id(self):
""" Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
......@@ -213,6 +218,7 @@ class PreTrainedTokenizer(object):
self._pad_token = None
self._cls_token = None
self._mask_token = None
self._pad_token_type_id = 0
self._additional_special_tokens = []
self.max_len = max_len if max_len is not None else int(1e12)
......@@ -696,6 +702,7 @@ class PreTrainedTokenizer(object):
max_length=None,
stride=0,
truncation_strategy='longest_first',
pad_to_max_length=False,
return_tensors=None,
**kwargs):
"""
......@@ -722,6 +729,8 @@ class PreTrainedTokenizer(object):
- 'only_first': Only truncate the first sequence
- 'only_second': Only truncate the second sequence
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
pad_to_max_length: if set to `True`, the returned sequences will be padded according to the model's
padding index, up to their max length. If no max length is specified, no padding is done.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method
......@@ -732,6 +741,7 @@ class PreTrainedTokenizer(object):
add_special_tokens=add_special_tokens,
stride=stride,
truncation_strategy=truncation_strategy,
pad_to_max_length=pad_to_max_length,
return_tensors=return_tensors,
**kwargs)
......@@ -744,7 +754,12 @@ class PreTrainedTokenizer(object):
max_length=None,
stride=0,
truncation_strategy='longest_first',
pad_to_max_length=False,
return_tensors=None,
return_token_type_ids=True,
return_attention_mask=True,
return_overflowing_tokens=False,
return_special_tokens_mask=False,
**kwargs):
"""
Returns a dictionary containing the encoded sequence or sequence pair and additional informations:
......@@ -769,9 +784,37 @@ class PreTrainedTokenizer(object):
- 'only_first': Only truncate the first sequence
- 'only_second': Only truncate the second sequence
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
pad_to_max_length: if set to `True`, the returned sequences will be padded according to the model's
padding index, up to their max length. If no max length is specified, no padding is done.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
return_attention_mask: (optional) Set to False to avoir returning attention mask (default True)
return_overflowing_tokens: (optional) Set to True to return overflowing token information (default False).
return_special_tokens_mask: (optional) Set to True to return special tokens mask information (default False).
**kwargs: passed to the `self.tokenize()` method
Return:
A Dictionary of shape::
{
input_ids: list[int],
token_type_ids: list[int] if return_token_type_ids is True (default)
attention_mask: list[int] if return_attention_mask is True (default)
overflowing_tokens: list[int] if a ``max_length`` is specified and return_overflowing_tokens is True
num_truncated_tokens: int if a ``max_length`` is specified and return_overflowing_tokens is True
special_tokens_mask: list[int] if ``add_special_tokens`` if set to ``True`` and return_special_tokens_mask is True
}
With the fields:
``input_ids``: list of token ids to be fed to a model
``token_type_ids``: list of token type ids to be fed to a model
``attention_mask``: list of indices specifying which tokens should be attended to by the model
``overflowing_tokens``: list of overflowing tokens if a max length is specified.
``num_truncated_tokens``: number of overflowing tokens a ``max_length`` is specified
``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added
tokens and 1 specifying sequence tokens.
"""
def get_input_ids(text):
......@@ -790,13 +833,24 @@ class PreTrainedTokenizer(object):
return self.prepare_for_model(first_ids,
pair_ids=second_ids,
max_length=max_length,
pad_to_max_length=pad_to_max_length,
add_special_tokens=add_special_tokens,
stride=stride,
truncation_strategy=truncation_strategy,
return_tensors=return_tensors)
return_tensors=return_tensors,
return_attention_mask=return_attention_mask,
return_token_type_ids=return_token_type_ids,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask)
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=True, stride=0,
truncation_strategy='longest_first', return_tensors=None):
truncation_strategy='longest_first',
pad_to_max_length=False,
return_tensors=None,
return_token_type_ids=True,
return_attention_mask=True,
return_overflowing_tokens=False,
return_special_tokens_mask=False):
"""
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
It adds special tokens, truncates
......@@ -819,8 +873,14 @@ class PreTrainedTokenizer(object):
- 'only_first': Only truncate the first sequence
- 'only_second': Only truncate the second sequence
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
pad_to_max_length: if set to `True`, the returned sequences will be padded according to the model's
padding index, up to their max length. If no max length is specified, no padding is done.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
return_attention_mask: (optional) Set to False to avoir returning attention mask (default True)
return_overflowing_tokens: (optional) Set to True to return overflowing token information (default False).
return_special_tokens_mask: (optional) Set to True to return special tokens mask information (default False).
Return:
A Dictionary of shape::
......@@ -883,6 +943,19 @@ class PreTrainedTokenizer(object):
"for this model ({} > {}). Running this sequence through the model will result in "
"indexing errors".format(len(ids), self.max_len))
if pad_to_max_length and max_length and len(encoded_inputs["input_ids"]) < max_length:
difference = max_length - len(encoded_inputs["input_ids"])
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference
if return_token_type_ids:
encoded_inputs["token_type_ids"] += [self.pad_token_type_id] * difference
if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] += [1] * difference
encoded_inputs["input_ids"] += [self.pad_token_id] * difference
elif return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
return encoded_inputs
def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first', stride=0):
......
......@@ -74,6 +74,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
self._pad_token_type_id = 3
try:
import sentencepiece as spm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment