"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "980211a63a2a07057a97b1eb47b7b09d7eda2bcd"
Unverified Commit 5afca00b authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #1724 from huggingface/fix_encode_plus

Fix encode_plus
parents 49108288 8d6b9d71
...@@ -273,7 +273,11 @@ class CommonTestCases: ...@@ -273,7 +273,11 @@ class CommonTestCases:
sequence = tokenizer.encode(seq_0, add_special_tokens=False) sequence = tokenizer.encode(seq_0, add_special_tokens=False)
num_added_tokens = tokenizer.num_added_tokens() num_added_tokens = tokenizer.num_added_tokens()
total_length = len(sequence) + num_added_tokens total_length = len(sequence) + num_added_tokens
information = tokenizer.encode_plus(seq_0, max_length=total_length - 2, add_special_tokens=True, stride=stride) information = tokenizer.encode_plus(seq_0,
max_length=total_length - 2,
add_special_tokens=True,
stride=stride,
return_overflowing_tokens=True)
truncated_sequence = information["input_ids"] truncated_sequence = information["input_ids"]
overflowing_tokens = information["overflowing_tokens"] overflowing_tokens = information["overflowing_tokens"]
...@@ -300,10 +304,12 @@ class CommonTestCases: ...@@ -300,10 +304,12 @@ class CommonTestCases:
) )
information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True, information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True,
stride=stride, truncation_strategy='only_second') stride=stride, truncation_strategy='only_second',
return_overflowing_tokens=True)
information_first_truncated = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, information_first_truncated = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2,
add_special_tokens=True, stride=stride, add_special_tokens=True, stride=stride,
truncation_strategy='only_first') truncation_strategy='only_first',
return_overflowing_tokens=True)
truncated_sequence = information["input_ids"] truncated_sequence = information["input_ids"]
overflowing_tokens = information["overflowing_tokens"] overflowing_tokens = information["overflowing_tokens"]
...@@ -335,7 +341,7 @@ class CommonTestCases: ...@@ -335,7 +341,7 @@ class CommonTestCases:
# Testing single inputs # Testing single inputs
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False) encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True) encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True, return_special_tokens_mask=True)
encoded_sequence_w_special = encoded_sequence_dict["input_ids"] encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"] special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special)) self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
...@@ -347,7 +353,8 @@ class CommonTestCases: ...@@ -347,7 +353,8 @@ class CommonTestCases:
# Testing inputs pairs # Testing inputs pairs
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False) + tokenizer.encode(sequence_1, encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False) + tokenizer.encode(sequence_1,
add_special_tokens=False) add_special_tokens=False)
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, sequence_1, add_special_tokens=True) encoded_sequence_dict = tokenizer.encode_plus(sequence_0, sequence_1, add_special_tokens=True,
return_special_tokens_mask=True)
encoded_sequence_w_special = encoded_sequence_dict["input_ids"] encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"] special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special)) self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
...@@ -359,7 +366,9 @@ class CommonTestCases: ...@@ -359,7 +366,9 @@ class CommonTestCases:
# Testing with already existing special tokens # Testing with already existing special tokens
if tokenizer.cls_token_id == tokenizer.unk_token_id and tokenizer.cls_token_id == tokenizer.unk_token_id: if tokenizer.cls_token_id == tokenizer.unk_token_id and tokenizer.cls_token_id == tokenizer.unk_token_id:
tokenizer.add_special_tokens({'cls_token': '</s>', 'sep_token': '<s>'}) tokenizer.add_special_tokens({'cls_token': '</s>', 'sep_token': '<s>'})
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True) encoded_sequence_dict = tokenizer.encode_plus(sequence_0,
add_special_tokens=True,
return_special_tokens_mask=True)
encoded_sequence_w_special = encoded_sequence_dict["input_ids"] encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
special_tokens_mask_orig = encoded_sequence_dict["special_tokens_mask"] special_tokens_mask_orig = encoded_sequence_dict["special_tokens_mask"]
special_tokens_mask = tokenizer.get_special_tokens_mask(encoded_sequence_w_special, already_has_special_tokens=True) special_tokens_mask = tokenizer.get_special_tokens_mask(encoded_sequence_w_special, already_has_special_tokens=True)
......
...@@ -750,6 +750,9 @@ class PreTrainedTokenizer(object): ...@@ -750,6 +750,9 @@ class PreTrainedTokenizer(object):
stride=0, stride=0,
truncation_strategy='longest_first', truncation_strategy='longest_first',
return_tensors=None, return_tensors=None,
return_token_type_ids=True,
return_overflowing_tokens=False,
return_special_tokens_mask=False,
**kwargs): **kwargs):
""" """
Returns a dictionary containing the encoded sequence or sequence pair and additional informations: Returns a dictionary containing the encoded sequence or sequence pair and additional informations:
...@@ -776,7 +779,30 @@ class PreTrainedTokenizer(object): ...@@ -776,7 +779,30 @@ class PreTrainedTokenizer(object):
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length) - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers. or PyTorch torch.Tensor instead of a list of python integers.
return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
return_overflowing_tokens: (optional) Set to True to return overflowing token information (default False).
return_special_tokens_mask: (optional) Set to True to return special tokens mask information (default False).
**kwargs: passed to the `self.tokenize()` method **kwargs: passed to the `self.tokenize()` method
Return:
A Dictionary of shape::
{
input_ids: list[int],
token_type_ids: list[int] if return_token_type_ids is True (default)
overflowing_tokens: list[int] if a ``max_length`` is specified and return_overflowing_tokens is True
num_truncated_tokens: int if a ``max_length`` is specified and return_overflowing_tokens is True
special_tokens_mask: list[int] if ``add_special_tokens`` if set to ``True`` and return_special_tokens_mask is True
}
With the fields:
``input_ids``: list of token ids to be fed to a model
``token_type_ids``: list of token type ids to be fed to a model
``overflowing_tokens``: list of overflowing tokens if a max length is specified.
``num_truncated_tokens``: number of overflowing tokens a ``max_length`` is specified
``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added
tokens and 1 specifying sequence tokens.
""" """
def get_input_ids(text): def get_input_ids(text):
...@@ -798,10 +824,17 @@ class PreTrainedTokenizer(object): ...@@ -798,10 +824,17 @@ class PreTrainedTokenizer(object):
add_special_tokens=add_special_tokens, add_special_tokens=add_special_tokens,
stride=stride, stride=stride,
truncation_strategy=truncation_strategy, truncation_strategy=truncation_strategy,
return_tensors=return_tensors) return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask)
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=True, stride=0, def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=True, stride=0,
truncation_strategy='longest_first', return_tensors=None): truncation_strategy='longest_first',
return_tensors=None,
return_token_type_ids=True,
return_overflowing_tokens=False,
return_special_tokens_mask=False):
""" """
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
It adds special tokens, truncates It adds special tokens, truncates
...@@ -826,21 +859,27 @@ class PreTrainedTokenizer(object): ...@@ -826,21 +859,27 @@ class PreTrainedTokenizer(object):
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length) - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers. or PyTorch torch.Tensor instead of a list of python integers.
return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
return_overflowing_tokens: (optional) Set to True to return overflowing token information (default False).
return_special_tokens_mask: (optional) Set to True to return special tokens mask information (default False).
Return: Return:
A Dictionary of shape:: A Dictionary of shape::
{ {
input_ids: list[int], input_ids: list[int],
overflowing_tokens: list[int] if a ``max_length`` is specified, else None token_type_ids: list[int] if return_token_type_ids is True (default)
special_tokens_mask: list[int] if ``add_special_tokens`` if set to ``True`` overflowing_tokens: list[int] if a ``max_length`` is specified and return_overflowing_tokens is True
num_truncated_tokens: int if a ``max_length`` is specified and return_overflowing_tokens is True
special_tokens_mask: list[int] if ``add_special_tokens`` if set to ``True`` and return_special_tokens_mask is True
} }
With the fields: With the fields:
``input_ids``: list of tokens to be fed to a model ``input_ids``: list of token ids to be fed to a model
``token_type_ids``: list of token type ids to be fed to a model
``overflowing_tokens``: list of overflowing tokens if a max length is specified. ``overflowing_tokens``: list of overflowing tokens if a max length is specified.
``num_truncated_tokens``: number of overflowing tokens a ``max_length`` is specified
``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added ``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added
tokens and 1 specifying sequence tokens. tokens and 1 specifying sequence tokens.
""" """
...@@ -849,23 +888,31 @@ class PreTrainedTokenizer(object): ...@@ -849,23 +888,31 @@ class PreTrainedTokenizer(object):
len_pair_ids = len(pair_ids) if pair else 0 len_pair_ids = len(pair_ids) if pair else 0
encoded_inputs = {} encoded_inputs = {}
# Handle max sequence length
total_len = len_ids + len_pair_ids + (self.num_added_tokens(pair=pair) if add_special_tokens else 0) total_len = len_ids + len_pair_ids + (self.num_added_tokens(pair=pair) if add_special_tokens else 0)
if max_length and total_len > max_length: if max_length and total_len > max_length:
ids, pair_ids, overflowing_tokens = self.truncate_sequences(ids, pair_ids=pair_ids, ids, pair_ids, overflowing_tokens = self.truncate_sequences(ids, pair_ids=pair_ids,
num_tokens_to_remove=total_len-max_length, num_tokens_to_remove=total_len-max_length,
truncation_strategy=truncation_strategy, truncation_strategy=truncation_strategy,
stride=stride) stride=stride)
encoded_inputs["overflowing_tokens"] = overflowing_tokens if return_overflowing_tokens:
encoded_inputs["num_truncated_tokens"] = total_len - max_length encoded_inputs["overflowing_tokens"] = overflowing_tokens
encoded_inputs["num_truncated_tokens"] = total_len - max_length
# Handle special_tokens
if add_special_tokens: if add_special_tokens:
sequence = self.build_inputs_with_special_tokens(ids, pair_ids) sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids) special_tokens_mask = self.get_special_tokens_mask(ids, pair_ids)
else: else:
sequence = ids + pair_ids if pair else ids sequence = ids + pair_ids if pair else ids
token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else []) token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
special_tokens_mask = [0] * (len(ids) + (len(pair_ids) if pair else 0))
if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
# Prepare inputs as tensors if asked
if return_tensors == 'tf' and is_tf_available(): if return_tensors == 'tf' and is_tf_available():
sequence = tf.constant([sequence]) sequence = tf.constant([sequence])
token_type_ids = tf.constant([token_type_ids]) token_type_ids = tf.constant([token_type_ids])
...@@ -876,12 +923,15 @@ class PreTrainedTokenizer(object): ...@@ -876,12 +923,15 @@ class PreTrainedTokenizer(object):
logger.warning("Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(return_tensors)) logger.warning("Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(return_tensors))
encoded_inputs["input_ids"] = sequence encoded_inputs["input_ids"] = sequence
encoded_inputs["token_type_ids"] = token_type_ids if return_token_type_ids:
encoded_inputs["token_type_ids"] = token_type_ids
if max_length and len(encoded_inputs["input_ids"]) > max_length: if max_length and len(encoded_inputs["input_ids"]) > max_length:
encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length] encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length]
encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length] if return_token_type_ids:
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"][:max_length] encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length]
if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"][:max_length]
if max_length is None and len(encoded_inputs["input_ids"]) > self.max_len: if max_length is None and len(encoded_inputs["input_ids"]) > self.max_len:
logger.warning("Token indices sequence length is longer than the specified maximum sequence length " logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment