Commit 7c789c33 authored by LysandreJik's avatar LysandreJik
Browse files

Always truncate argument in the encode method

parent 7af07779
...@@ -232,6 +232,23 @@ class CommonTestCases: ...@@ -232,6 +232,23 @@ class CommonTestCases:
assert len(truncated_sequence) == total_length - 2 assert len(truncated_sequence) == total_length - 2
assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2]) assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2])
def test_always_truncate(self):
tokenizer = self.get_tokenizer()
seq_0 = "This is a sentence to be encoded."
length_single_sequence = len(tokenizer.encode(seq_0))
length = len(tokenizer.encode(seq_0, seq_0, add_special_tokens=True))
not_truncated = tokenizer.encode(seq_0, seq_0, add_special_tokens=True, max_length=length_single_sequence)
truncated = tokenizer.encode(
seq_0, seq_0,
max_length=length_single_sequence,
add_special_tokens=True,
always_truncate=True
)
assert truncated == not_truncated[:length_single_sequence - length]
def test_maximum_encoding_length_pair_input(self): def test_maximum_encoding_length_pair_input(self):
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
......
...@@ -693,14 +693,15 @@ class PreTrainedTokenizer(object): ...@@ -693,14 +693,15 @@ class PreTrainedTokenizer(object):
raise NotImplementedError raise NotImplementedError
def encode(self, def encode(self,
text, text,
text_pair=None, text_pair=None,
add_special_tokens=False, add_special_tokens=False,
max_length=None, max_length=None,
stride=0, stride=0,
truncate_first_sequence=True, truncate_first_sequence=True,
return_tensors=None, return_tensors=None,
**kwargs): always_truncate=False,
**kwargs):
""" """
Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary. Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
...@@ -721,6 +722,8 @@ class PreTrainedTokenizer(object): ...@@ -721,6 +722,8 @@ class PreTrainedTokenizer(object):
from the main sequence returned. The value of this argument defined the number of additional tokens. from the main sequence returned. The value of this argument defined the number of additional tokens.
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
will be truncated. will be truncated.
always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the
sequences may be lost in the process.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers. or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method **kwargs: passed to the `self.tokenize()` method
...@@ -732,6 +735,7 @@ class PreTrainedTokenizer(object): ...@@ -732,6 +735,7 @@ class PreTrainedTokenizer(object):
stride=stride, stride=stride,
truncate_first_sequence=truncate_first_sequence, truncate_first_sequence=truncate_first_sequence,
return_tensors=return_tensors, return_tensors=return_tensors,
always_truncate=always_truncate,
**kwargs) **kwargs)
return encoded_inputs["input_ids"] return encoded_inputs["input_ids"]
...@@ -744,6 +748,7 @@ class PreTrainedTokenizer(object): ...@@ -744,6 +748,7 @@ class PreTrainedTokenizer(object):
stride=0, stride=0,
truncate_first_sequence=True, truncate_first_sequence=True,
return_tensors=None, return_tensors=None,
always_truncate=False,
**kwargs): **kwargs):
""" """
Returns a dictionary containing the encoded sequence or sequence pair and additional informations: Returns a dictionary containing the encoded sequence or sequence pair and additional informations:
...@@ -764,6 +769,8 @@ class PreTrainedTokenizer(object): ...@@ -764,6 +769,8 @@ class PreTrainedTokenizer(object):
from the main sequence returned. The value of this argument defined the number of additional tokens. from the main sequence returned. The value of this argument defined the number of additional tokens.
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
will be truncated. will be truncated.
always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the
sequences may be lost in the process.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers. or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method **kwargs: passed to the `self.tokenize()` method
...@@ -788,11 +795,12 @@ class PreTrainedTokenizer(object): ...@@ -788,11 +795,12 @@ class PreTrainedTokenizer(object):
add_special_tokens=add_special_tokens, add_special_tokens=add_special_tokens,
stride=stride, stride=stride,
truncate_first_sequence=truncate_first_sequence, truncate_first_sequence=truncate_first_sequence,
always_truncate=always_truncate,
return_tensors=return_tensors) return_tensors=return_tensors)
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0, def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0,
truncate_first_sequence=True, return_tensors=None): truncate_first_sequence=True, always_truncate=False, return_tensors=None):
""" """
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
It adds special tokens, truncates It adds special tokens, truncates
...@@ -812,6 +820,8 @@ class PreTrainedTokenizer(object): ...@@ -812,6 +820,8 @@ class PreTrainedTokenizer(object):
truncate_first_sequence: if set to `True` and an optional second list of input ids is provided, truncate_first_sequence: if set to `True` and an optional second list of input ids is provided,
alongside a specified `max_length`, will truncate the first sequence if the total size is superior alongside a specified `max_length`, will truncate the first sequence if the total size is superior
than the specified `max_length`. If set to `False`, will truncate the second sequence instead. than the specified `max_length`. If set to `False`, will truncate the second sequence instead.
always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the
sequences may be lost in the process.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers. or PyTorch torch.Tensor instead of a list of python integers.
...@@ -826,9 +836,14 @@ class PreTrainedTokenizer(object): ...@@ -826,9 +836,14 @@ class PreTrainedTokenizer(object):
if max_length: if max_length:
n_added_tokens = self.num_added_tokens(pair=pair) if add_special_tokens else 0 n_added_tokens = self.num_added_tokens(pair=pair) if add_special_tokens else 0
if pair and n_added_tokens + (len_pair_ids if truncate_first_sequence else len_ids) >= max_length: if pair and n_added_tokens + (len_pair_ids if truncate_first_sequence else len_ids) >= max_length:
logger.warning( if always_truncate:
"You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length." logger.warning(
"This pair of sequences will not be truncated.") "You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length. "
"This pair of sequences will be truncated but one of the sequences may not be present in the resulting list of ids.")
else:
logger.warning(
"You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length. "
"This pair of sequences will not be truncated.")
else: else:
if n_added_tokens + len_ids + len_pair_ids > max_length: if n_added_tokens + len_ids + len_pair_ids > max_length:
if truncate_first_sequence or not pair: if truncate_first_sequence or not pair:
...@@ -860,6 +875,10 @@ class PreTrainedTokenizer(object): ...@@ -860,6 +875,10 @@ class PreTrainedTokenizer(object):
encoded_inputs["input_ids"] = sequence encoded_inputs["input_ids"] = sequence
encoded_inputs["token_type_ids"] = token_type_ids encoded_inputs["token_type_ids"] = token_type_ids
if always_truncate and len(encoded_inputs["input_ids"]) > max_length:
encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length]
encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length]
return encoded_inputs return encoded_inputs
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1): def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment