Commit 66ea76b8 authored by LysandreJik

prepare_for_model and prepare_pair_for_model methods. Added an option to select which sequence will be truncated.
parent 60414f31
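
The headline change: `encode_plus` (and the new `prepare_pair_for_model` helper it delegates to) gains a `truncate_second_sequence_first` flag, so callers can decide which side of a pair loses tokens when `max_length` is exceeded. A quick usage sketch (hedged: `tokenizer` is assumed to be an already-loaded pretrained tokenizer, and the concrete texts and lengths are illustrative; the returned keys follow the diff below):

    # Overflow is normally taken from the second sequence; passing
    # truncate_second_sequence_first=False takes it from the first instead.
    information = tokenizer.encode_plus("First sequence.", "A second, longer sequence.",
                                        add_special_tokens=True, max_length=12, stride=2,
                                        truncate_second_sequence_first=False)
    print(information["sequence"])            # truncated ids, special tokens included
    print(information["overflowing_tokens"])  # tail of the first sequence, plus stride overlap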
@@ -237,16 +237,29 @@ class CommonTestCases:
         seq_0 = "This is a sentence to be encoded."
         seq_1 = "This is another sentence to be encoded."
         stride = 2
+        sequence_0_no_special_tokens = tokenizer.encode(seq_0)
+        sequence_1_no_special_tokens = tokenizer.encode(seq_1)
         sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
         truncated_second_sequence = tokenizer.add_special_tokens_sequence_pair(
             tokenizer.encode(seq_0),
             tokenizer.encode(seq_1)[:-2]
         )
-        information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True)
+        information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True,
+                                            stride=stride)
+        information_first_truncated = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2,
+                                                            add_special_tokens=True, stride=stride,
+                                                            truncate_second_sequence_first=False)
         truncated_sequence = information["sequence"]
         overflowing_tokens = information["overflowing_tokens"]
+        overflowing_tokens_first_truncated = information_first_truncated["overflowing_tokens"]
         assert len(overflowing_tokens) == 2 + stride
+        assert overflowing_tokens == sequence_1_no_special_tokens[-(2 + stride):]
+        assert overflowing_tokens_first_truncated == sequence_0_no_special_tokens[-(2 + stride):]
         assert len(truncated_sequence) == len(sequence) - 2
         assert truncated_sequence == truncated_second_sequence
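
The new assertions rest on one piece of slice arithmetic: if truncation drops `n` tokens and `stride` is set, the overflow window is the last `n + stride` ids of whichever sequence was cut. A standalone sketch of that rule (toy ids, no tokenizer required):

    ids = list(range(10))    # stand-in token ids for the truncated sequence
    num_truncated = 2        # max_length is 2 short of the full encoded pair
    stride = 2

    overflowing = ids[-(num_truncated + stride):]  # last 4 ids, stride overlap included
    kept = ids[:-num_truncated]                    # everything but the last 2

    assert overflowing == [6, 7, 8, 9]
    assert kept == list(range(8))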
@@ -722,7 +722,15 @@ class PreTrainedTokenizer(object):
logger.warning("No special tokens were added. The two sequences have been concatenated.")
return first_sentence_tokens + second_sentence_tokens
def encode_plus(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, stride=0, **kwargs):
def encode_plus(self,
text,
text_pair=None,
add_special_tokens=False,
output_mask=False,
max_length=None,
stride=0,
truncate_second_sequence_first=True,
**kwargs):
"""
Returns a dictionary containing the encoded sequence or sequence pair. Other values can be returned by this
method: the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
@@ -738,54 +746,40 @@ class PreTrainedTokenizer(object):
                 If there are overflowing tokens, those will be added to the returned dictionary
             stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
                 from the main sequence returned. The value of this argument defines the number of additional tokens.
+            truncate_second_sequence_first: if there is a specified max_length, this flag will choose which sequence
+                will be truncated.
             **kwargs: passed to the `self.tokenize()` method
         """
         information = {}

         if text_pair is None:
-            n_added_tokens = self.num_added_tokens()
-            if add_special_tokens:
-                sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
-                if max_length:
-                    information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens - stride:]
-                    sequence_tokens = sequence_tokens[:max_length - n_added_tokens]
-                sequence = self.add_special_tokens_single_sequence(sequence_tokens)
+            if add_special_tokens:
+                information = self.prepare_for_model(sequence_tokens, max_length, stride)
             else:
                 sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
                 if max_length:
                     information["overflowing_tokens"] = sequence_tokens[max_length - stride:]
                     sequence_tokens = sequence_tokens[:max_length]
-                sequence = sequence_tokens
+                information["sequence"] = sequence_tokens

             if output_mask:
-                information["mask"] = [0] * len(sequence)
-            information["sequence"] = sequence
+                information["mask"] = [0] * len(information["sequence"])
         else:
             first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
             second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
-            f_len, s_len = len(first_sentence_tokens), len(second_sentence_tokens)
-            n_added_tokens = self.num_added_tokens(pair=True)

             if add_special_tokens:
-                if max_length:
-                    if len(first_sentence_tokens) + n_added_tokens >= max_length:
-                        logger.warning("The first sequence is longer than the maximum specified length. This sequence will not be truncated.")
-                    else:
-                        if f_len + s_len + self.num_added_tokens(pair=True) > max_length:
-                            information["overflowing_tokens"] = second_sentence_tokens[max_length - f_len - n_added_tokens:]
-                            second_sentence_tokens = second_sentence_tokens[:max_length - f_len - n_added_tokens]
-                sequence = self.add_special_tokens_sequence_pair(
+                information = self.prepare_pair_for_model(
                     first_sentence_tokens,
-                    second_sentence_tokens
+                    second_sentence_tokens,
+                    max_length,
+                    truncate_second_sequence_first,
+                    stride
                 )
                 if output_mask:
                     information["mask"] = self.create_mask_from_sequences(text, text_pair)
-                information["sequence"] = sequence
             else:
                 logger.warning("No special tokens were added. The two sequences have been concatenated.")
                 sequence = first_sentence_tokens + second_sentence_tokens
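
With this refactor the pair branch of `encode_plus` becomes a thin wrapper: tokenize both texts, then hand everything to `prepare_pair_for_model`. Roughly equivalent, as a sketch built from the same calls the diff uses (`tokenizer`, `seq_0`, `seq_1`, `m`, and `s` are placeholders):

    ids_0 = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(seq_0))
    ids_1 = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(seq_1))
    information = tokenizer.prepare_pair_for_model(ids_0, ids_1, max_length=m, stride=s)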
@@ -800,6 +794,39 @@ class PreTrainedTokenizer(object):
         return information

+    def prepare_for_model(self, ids, max_length=None, stride=0):
+        information = {}
+        n_added_tokens = self.num_added_tokens()
+
+        if max_length:
+            information["overflowing_tokens"] = ids[max_length - n_added_tokens - stride:]
+            ids = ids[:max_length - n_added_tokens]
+
+        information["sequence"] = self.add_special_tokens_single_sequence(ids)
+
+        return information
+
+    def prepare_pair_for_model(self, ids_0, ids_1, max_length=None, truncate_second_sequence_first=True, stride=0):
+        f_len, s_len = len(ids_0), len(ids_1)
+        n_added_tokens = self.num_added_tokens(pair=True)
+        information = {}
+
+        if max_length:
+            if len(ids_0) + n_added_tokens >= max_length:
+                logger.warning(
+                    "The first sequence is longer than the maximum specified length. This sequence will not be truncated.")
+            else:
+                if f_len + s_len + self.num_added_tokens(pair=True) > max_length:
+                    if truncate_second_sequence_first:
+                        information["overflowing_tokens"] = ids_1[max_length - f_len - n_added_tokens - stride:]
+                        ids_1 = ids_1[:max_length - f_len - n_added_tokens]
+                    else:
+                        information["overflowing_tokens"] = ids_0[max_length - s_len - n_added_tokens - stride:]
+                        ids_0 = ids_0[:max_length - s_len - n_added_tokens]
+
+        sequence = self.add_special_tokens_sequence_pair(ids_0, ids_1)
+        information["sequence"] = sequence
+
+        return information
+
     def create_mask_from_sequences(self, sequence_0, sequence_1):
         logger.warning("This tokenizer does not make use of special tokens.")
         return [0] * len(self.encode(sequence_0)) + [1] * len(self.encode(sequence_1))
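
To make the new truncation rule concrete, here is a self-contained mock of `prepare_pair_for_model`'s core logic (illustrative only: it hard-codes a 2-token special-token overhead instead of asking a real tokenizer via `num_added_tokens`, and uses string markers in place of real ids):

    def prepare_pair_sketch(ids_0, ids_1, max_length, truncate_second_sequence_first=True, stride=0):
        n_added = 2  # assumed [CLS]/[SEP]-style overhead for a pair
        information = {}
        overflow = len(ids_0) + len(ids_1) + n_added - max_length
        if overflow > 0:
            if truncate_second_sequence_first:
                information["overflowing_tokens"] = ids_1[-(overflow + stride):]
                ids_1 = ids_1[:-overflow]
            else:
                information["overflowing_tokens"] = ids_0[-(overflow + stride):]
                ids_0 = ids_0[:-overflow]
        information["sequence"] = ["[CLS]"] + ids_0 + ["[SEP]"] + ids_1
        return information

    print(prepare_pair_sketch([1, 2, 3, 4], [5, 6, 7, 8], max_length=9, stride=1))
    # {'overflowing_tokens': [7, 8], 'sequence': ['[CLS]', 1, 2, 3, 4, '[SEP]', 5, 6, 7]}
    print(prepare_pair_sketch([1, 2, 3, 4], [5, 6, 7, 8], max_length=9, truncate_second_sequence_first=False))
    # {'overflowing_tokens': [4], 'sequence': ['[CLS]', 1, 2, 3, '[SEP]', 5, 6, 7, 8]}

As in the committed version, only the sequence chosen by the flag is ever cut; the stride only widens the returned overflow window, it does not change what is kept.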